diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4ce04787b..c23ae227a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,10 +13,10 @@ repos: args: [--pytest-test-first] exclude: ^tests/unit/helpers/ - id: check-docstring-first -# - repo: https://github.com/psf/black -# rev: 23.3.0 -# hooks: -# - id: black +- repo: https://github.com/psf/black + rev: 23.3.0 + hooks: + - id: black - repo: https://github.com/charliermarsh/ruff-pre-commit rev: v0.0.272 hooks: diff --git a/CHANGELOG.md b/CHANGELOG.md index ad3d40228..e59db1a4b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,10 @@ # HDMF Changelog +## HDMF 3.7.0 (Upcoming) + +## Minor improvements +- Set code style guide to `black`. @rly [#860](https://github.com/hdmf-dev/hdmf/pull/860) + ## HMDF 3.6.1 (May 18, 2023) ### Bug fixes diff --git a/docs/CONTRIBUTING.rst b/docs/CONTRIBUTING.rst index 052fed7b7..981de28da 100644 --- a/docs/CONTRIBUTING.rst +++ b/docs/CONTRIBUTING.rst @@ -90,21 +90,25 @@ Style Guides Python Code Style Guide ^^^^^^^^^^^^^^^^^^^^^^^ -Before you create a Pull Request, make sure you are following the PEP8_ style guide. . -To check whether your code conforms to the HDMF style guide, simply run the ruff_ tool in the project's root -directory. ``ruff`` will also sort imports automatically and check against additional code style rules. +Before you create a Pull Request, make sure you are following the Black_ preview style guide, which follows PEP8. +We also break from the Black format by configuring it with a max line length of 120. +To check whether your code conforms to the Black_ preview style guide, simply run the ``black`` tool in the +project's root directory with the ``--check`` argument. You can also run the ``black`` tool without +the ``--check`` argument to have black automatically format the codebase to comply with the style guide. -We also use ``ruff`` to sort python imports automatically and double-check that the codebase +We also use the ruff_ tool to sort python imports automatically and double-check that the codebase conforms to PEP8 standards, while using the codespell_ tool to check spelling. -``ruff`` and ``codespell`` are installed when you follow the developer installation instructions. See +The ``black``, ``ruff``, and ``codespell`` tools are installed when you follow the developer installation instructions. See :ref:`install_developers`. +.. _Black: https://black.readthedocs.io/en/stable/ .. _ruff: https://beta.ruff.rs/docs/ .. _codespell: https://github.com/codespell-project/codespell .. code:: + $ black . $ ruff check . 
$ codespell diff --git a/pyproject.toml b/pyproject.toml index 9b7fac7af..dc979bc9d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,11 +85,11 @@ exclude_lines = [ [tool.setuptools_scm] -# [tool.black] -# line-length = 120 -# preview = true -# exclude = ".git|.mypy_cache|.tox|.venv|venv|.ipynb_checkpoints|_build/|dist/|__pypackages__|.ipynb" -# force-exclude = "src/hdmf/common/hdmf-common-schema|docs/gallery" +[tool.black] +line-length = 120 +preview = true +exclude = ".git|.mypy_cache|.tox|.venv|venv|.ipynb_checkpoints|_build/|dist/|__pypackages__|.ipynb" +force-exclude = "src/hdmf/common/hdmf-common-schema|docs/gallery" [tool.ruff] select = ["E", "F", "T100", "T201", "T203"] diff --git a/src/hdmf/array.py b/src/hdmf/array.py index a684572e4..597a4c82d 100644 --- a/src/hdmf/array.py +++ b/src/hdmf/array.py @@ -1,13 +1,12 @@ -from abc import abstractmethod, ABCMeta +from abc import ABCMeta, abstractmethod import numpy as np class Array: - def __init__(self, data): self.__data = data - if hasattr(data, 'dtype'): + if hasattr(data, "dtype"): self.dtype = data.dtype else: tmp = data @@ -41,7 +40,10 @@ def __getitem__(self, arg): idx.append(i) return np.fromiter((self.__getidx__(x) for x in idx), dtype=self.dtype) elif isinstance(arg, slice): - return np.fromiter((self.__getidx__(x) for x in self.__sliceiter(arg)), dtype=self.dtype) + return np.fromiter( + (self.__getidx__(x) for x in self.__sliceiter(arg)), + dtype=self.dtype, + ) elif isinstance(arg, tuple): return (self.__getidx__(arg[0]), self.__getidx__(arg[1])) else: @@ -49,9 +51,9 @@ def __getitem__(self, arg): class AbstractSortedArray(Array, metaclass=ABCMeta): - ''' + """ An abstract class for representing sorted array - ''' + """ @abstractmethod def find_point(self, val): @@ -160,11 +162,11 @@ def __ne__(self, other): class SortedArray(AbstractSortedArray): - ''' + """ A class for wrapping sorted arrays. This class overrides <,>,<=,>=,==, and != to leverage the sorted content for efficiency. - ''' + """ def __init__(self, array): super().__init__(array) @@ -174,7 +176,6 @@ def find_point(self, val): class LinSpace(SortedArray): - def __init__(self, start, stop, step): self.start = start self.stop = stop diff --git a/src/hdmf/backends/hdf5/__init__.py b/src/hdmf/backends/hdf5/__init__.py index 6abfc8c85..989e09396 100644 --- a/src/hdmf/backends/hdf5/__init__.py +++ b/src/hdmf/backends/hdf5/__init__.py @@ -1,3 +1,3 @@ from . 
import h5_utils, h5tools -from .h5_utils import H5RegionSlicer, H5DataIO -from .h5tools import HDF5IO, H5SpecWriter, H5SpecReader +from .h5_utils import H5DataIO, H5RegionSlicer +from .h5tools import HDF5IO, H5SpecReader, H5SpecWriter diff --git a/src/hdmf/backends/hdf5/h5_utils.py b/src/hdmf/backends/hdf5/h5_utils.py index b39d540a2..56be397cf 100644 --- a/src/hdmf/backends/hdf5/h5_utils.py +++ b/src/hdmf/backends/hdf5/h5_utils.py @@ -3,25 +3,26 @@ e.g., for wrapping HDF5 datasets on read, wrapping arrays for configuring write, or writing the spec among others""" -from collections import deque +import json +import logging +import os +import warnings from abc import ABCMeta, abstractmethod +from collections import deque from collections.abc import Iterable from copy import copy -from h5py import Group, Dataset, RegionReference, Reference, special_dtype -from h5py import filters as h5py_filters -import json import numpy as np -import warnings -import os -import logging +from h5py import Dataset, Group, Reference, RegionReference +from h5py import filters as h5py_filters +from h5py import special_dtype from ...array import Array -from ...data_utils import DataIO, AbstractDataChunkIterator -from ...query import HDMFDataset, ReferenceResolver, ContainerResolver, BuilderResolver +from ...data_utils import AbstractDataChunkIterator, DataIO +from ...query import BuilderResolver, ContainerResolver, HDMFDataset, ReferenceResolver from ...region import RegionSlicer -from ...spec import SpecWriter, SpecReader -from ...utils import docval, getargs, popargs, get_docval +from ...spec import SpecReader, SpecWriter +from ...utils import docval, get_docval, getargs, popargs class HDF5IODataChunkIteratorQueue(deque): @@ -31,8 +32,9 @@ class HDF5IODataChunkIteratorQueue(deque): Each queue element must be a tuple of two elements: 1) the dataset to write to and 2) the AbstractDataChunkIterator with the data """ + def __init__(self): - self.logger = logging.getLogger('%s.%s' % (self.__class__.__module__, self.__class__.__qualname__)) + self.logger = logging.getLogger("%s.%s" % (self.__class__.__module__, self.__class__.__qualname__)) super().__init__() @classmethod @@ -85,10 +87,12 @@ def append(self, dataset, data): class H5Dataset(HDMFDataset): - @docval({'name': 'dataset', 'type': (Dataset, Array), 'doc': 'the HDF5 file lazily evaluate'}, - {'name': 'io', 'type': 'HDF5IO', 'doc': 'the IO object that was used to read the underlying dataset'}) + @docval( + {"name": "dataset", "type": (Dataset, Array), "doc": "the HDF5 file lazily evaluate"}, + {"name": "io", "type": "HDF5IO", "doc": "the IO object that was used to read the underlying dataset"}, + ) def __init__(self, **kwargs): - self.__io = popargs('io', kwargs) + self.__io = popargs("io", kwargs) super().__init__(**kwargs) @property @@ -126,12 +130,12 @@ def invert(self): Return an object that defers reference resolution but in the opposite direction. 
""" - if not hasattr(self, '__inverted'): + if not hasattr(self, "__inverted"): cls = self.get_inverse_class() docval = get_docval(cls.__init__) kwargs = dict() for arg in docval: - kwargs[arg['name']] = getattr(self, arg['name']) + kwargs[arg["name"]] = getattr(self, arg["name"]) self.__inverted = cls(**kwargs) return self.__inverted @@ -173,13 +177,25 @@ def get_object(self, h5obj): class AbstractH5TableDataset(DatasetOfReferences): - - @docval({'name': 'dataset', 'type': (Dataset, Array), 'doc': 'the HDF5 file lazily evaluate'}, - {'name': 'io', 'type': 'HDF5IO', 'doc': 'the IO object that was used to read the underlying dataset'}, - {'name': 'types', 'type': (list, tuple), - 'doc': 'the IO object that was used to read the underlying dataset'}) + @docval( + { + "name": "dataset", + "type": (Dataset, Array), + "doc": "the HDF5 file lazily evaluate", + }, + { + "name": "io", + "type": "HDF5IO", + "doc": "the IO object that was used to read the underlying dataset", + }, + { + "name": "types", + "type": (list, tuple), + "doc": "the IO object that was used to read the underlying dataset", + }, + ) def __init__(self, **kwargs): - types = popargs('types', kwargs) + types = popargs("types", kwargs) super().__init__(**kwargs) self.__refgetters = dict() for i, t in enumerate(types): @@ -197,18 +213,18 @@ def __init__(self, **kwargs): for i in range(len(self.dataset.dtype)): sub = self.dataset.dtype[i] if sub.metadata: - if 'vlen' in sub.metadata: - t = sub.metadata['vlen'] + if "vlen" in sub.metadata: + t = sub.metadata["vlen"] if t is str: - tmp.append('utf') + tmp.append("utf") elif t is bytes: - tmp.append('ascii') - elif 'ref' in sub.metadata: - t = sub.metadata['ref'] + tmp.append("ascii") + elif "ref" in sub.metadata: + t = sub.metadata["ref"] if t is Reference: - tmp.append('object') + tmp.append("object") elif t is RegionReference: - tmp.append('region') + tmp.append("region") else: tmp.append(sub.type.__name__) self.__dtype = tmp @@ -239,14 +255,14 @@ def _get_utf(self, string): """ Decode a dataset element to unicode """ - return string.decode('utf-8') if isinstance(string, bytes) else string + return string.decode("utf-8") if isinstance(string, bytes) else string def __get_regref(self, ref): obj = self._get_ref(ref) return obj[ref] def resolve(self, manager): - return self[0:len(self)] + return self[0 : len(self)] def __iter__(self): for i in range(len(self)): @@ -254,7 +270,6 @@ def __iter__(self): class AbstractH5ReferenceDataset(DatasetOfReferences): - def __getitem__(self, arg): ref = super().__getitem__(arg) if isinstance(ref, np.ndarray): @@ -264,11 +279,10 @@ def __getitem__(self, arg): @property def dtype(self): - return 'object' + return "object" class AbstractH5RegionDataset(AbstractH5ReferenceDataset): - def __getitem__(self, arg): obj = super().__getitem__(arg) ref = self.dataset[arg] @@ -276,7 +290,7 @@ def __getitem__(self, arg): @property def dtype(self): - return 'region' + return "region" class ContainerH5TableDataset(ContainerResolverMixin, AbstractH5TableDataset): @@ -346,19 +360,18 @@ def get_inverse_class(cls): class H5SpecWriter(SpecWriter): - __str_type = special_dtype(vlen=str) - @docval({'name': 'group', 'type': Group, 'doc': 'the HDF5 file to write specs to'}) + @docval({"name": "group", "type": Group, "doc": "the HDF5 file to write specs to"}) def __init__(self, **kwargs): - self.__group = getargs('group', kwargs) + self.__group = getargs("group", kwargs) @staticmethod def stringify(spec): - ''' + """ Converts a spec into a JSON string to write to a dataset - 
''' - return json.dumps(spec, separators=(',', ':')) + """ + return json.dumps(spec, separators=(",", ":")) def __write(self, d, name): data = self.stringify(d) @@ -370,15 +383,15 @@ def write_spec(self, spec, path): return self.__write(spec, path) def write_namespace(self, namespace, path): - return self.__write({'namespaces': [namespace]}, path) + return self.__write({"namespaces": [namespace]}, path) class H5SpecReader(SpecReader): """Class that reads cached JSON-formatted namespace and spec data from an HDF5 group.""" - @docval({'name': 'group', 'type': Group, 'doc': 'the HDF5 group to read specs from'}) + @docval({"name": "group", "type": Group, "doc": "the HDF5 group to read specs from"}) def __init__(self, **kwargs): - self.__group = popargs('group', kwargs) + self.__group = popargs("group", kwargs) source = "%s:%s" % (os.path.abspath(self.__group.file.name), self.__group.name) super().__init__(source=source) self.__cache = None @@ -389,7 +402,7 @@ def __read(self, path): s = s[0] if isinstance(s, bytes): - s = s.decode('UTF-8') + s = s.decode("UTF-8") d = json.loads(s) return d @@ -400,17 +413,18 @@ def read_spec(self, spec_path): def read_namespace(self, ns_path): if self.__cache is None: self.__cache = self.__read(ns_path) - ret = self.__cache['namespaces'] + ret = self.__cache["namespaces"] return ret class H5RegionSlicer(RegionSlicer): - - @docval({'name': 'dataset', 'type': (Dataset, H5Dataset), 'doc': 'the HDF5 dataset to slice'}, - {'name': 'region', 'type': RegionReference, 'doc': 'the region reference to use to slice'}) + @docval( + {"name": "dataset", "type": (Dataset, H5Dataset), "doc": "the HDF5 dataset to slice"}, + {"name": "region", "type": RegionReference, "doc": "the region reference to use to slice"}, + ) def __init__(self, **kwargs): - self.__dataset = getargs('dataset', kwargs) - self.__regref = getargs('region', kwargs) + self.__dataset = getargs("dataset", kwargs) + self.__regref = getargs("region", kwargs) self.__len = self.__dataset.regionref.selection(self.__regref)[0] self.__region = None @@ -432,105 +446,152 @@ class H5DataIO(DataIO): for data arrays. """ - @docval({'name': 'data', - 'type': (np.ndarray, list, tuple, Dataset, Iterable), - 'doc': 'the data to be written. NOTE: If an h5py.Dataset is used, all other settings but link_data' + - ' will be ignored as the dataset will either be linked to or copied as is in H5DataIO.', - 'default': None}, - {'name': 'maxshape', - 'type': tuple, - 'doc': 'Dataset will be resizable up to this shape (Tuple). Automatically enables chunking.' + - 'Use None for the axes you want to be unlimited.', - 'default': None}, - {'name': 'chunks', - 'type': (bool, tuple), - 'doc': 'Chunk shape or True to enable auto-chunking', - 'default': None}, - {'name': 'compression', - 'type': (str, bool, int), - 'doc': 'Compression strategy. If a bool is given, then gzip compression will be used by default.' + - 'http://docs.h5py.org/en/latest/high/dataset.html#dataset-compression', - 'default': None}, - {'name': 'compression_opts', - 'type': (int, tuple), - 'doc': 'Parameter for compression filter', - 'default': None}, - {'name': 'fillvalue', - 'type': None, - 'doc': 'Value to be returned when reading uninitialized parts of the dataset', - 'default': None}, - {'name': 'shuffle', - 'type': bool, - 'doc': 'Enable shuffle I/O filter. http://docs.h5py.org/en/latest/high/dataset.html#dataset-shuffle', - 'default': None}, - {'name': 'fletcher32', - 'type': bool, - 'doc': 'Enable fletcher32 checksum. 
http://docs.h5py.org/en/latest/high/dataset.html#dataset-fletcher32', - 'default': None}, - {'name': 'link_data', - 'type': bool, - 'doc': 'If data is an h5py.Dataset should it be linked to or copied. NOTE: This parameter is only ' + - 'allowed if data is an h5py.Dataset', - 'default': False}, - {'name': 'allow_plugin_filters', - 'type': bool, - 'doc': 'Enable passing dynamically loaded filters as compression parameter', - 'default': False}, - {'name': 'shape', - 'type': tuple, - 'doc': 'the shape of the new dataset, used only if data is None', - 'default': None}, - {'name': 'dtype', - 'type': (str, type, np.dtype), - 'doc': 'the data type of the new dataset, used only if data is None', - 'default': None} - ) + @docval( + { + "name": "data", + "type": (np.ndarray, list, tuple, Dataset, Iterable), + "doc": ( + "the data to be written. NOTE: If an h5py.Dataset is used, all other settings but link_data " + "will be ignored as the dataset will either be linked to or copied as is in H5DataIO." + ), + "default": None, + }, + { + "name": "maxshape", + "type": tuple, + "doc": ( + "Dataset will be resizable up to this shape (Tuple). Automatically enables chunking. " + "Use None for the axes you want to be unlimited." + ), + "default": None, + }, + { + "name": "chunks", + "type": (bool, tuple), + "doc": "Chunk shape or True to enable auto-chunking", + "default": None, + }, + { + "name": "compression", + "type": (str, bool, int), + "doc": ( + "Compression strategy. If a bool is given, then gzip compression will be used by default. " + "http://docs.h5py.org/en/latest/high/dataset.html#dataset-compression" + ), + "default": None, + }, + { + "name": "compression_opts", + "type": (int, tuple), + "doc": "Parameter for compression filter", + "default": None, + }, + { + "name": "fillvalue", + "type": None, + "doc": "Value to be returned when reading uninitialized parts of the dataset", + "default": None, + }, + { + "name": "shuffle", + "type": bool, + "doc": "Enable shuffle I/O filter. http://docs.h5py.org/en/latest/high/dataset.html#dataset-shuffle", + "default": None, + }, + { + "name": "fletcher32", + "type": bool, + "doc": "Enable fletcher32 checksum. http://docs.h5py.org/en/latest/high/dataset.html#dataset-fletcher32", + "default": None, + }, + { + "name": "link_data", + "type": bool, + "doc": ( + "If data is an h5py.Dataset should it be linked to or copied. 
NOTE: This parameter is only " + "allowed if data is an h5py.Dataset" + ), + "default": False, + }, + { + "name": "allow_plugin_filters", + "type": bool, + "doc": "Enable passing dynamically loaded filters as compression parameter", + "default": False, + }, + { + "name": "shape", + "type": tuple, + "doc": "the shape of the new dataset, used only if data is None", + "default": None, + }, + { + "name": "dtype", + "type": (str, type, np.dtype), + "doc": "the data type of the new dataset, used only if data is None", + "default": None, + }, + ) def __init__(self, **kwargs): # Get the list of I/O options that user has passed in - ioarg_names = [name for name in kwargs.keys() if name not in ['data', 'link_data', 'allow_plugin_filters', - 'dtype', 'shape']] + ioarg_names = [ + name + for name in kwargs.keys() + if name + not in [ + "data", + "link_data", + "allow_plugin_filters", + "dtype", + "shape", + ] + ] # Remove the ioargs from kwargs ioarg_values = [popargs(argname, kwargs) for argname in ioarg_names] # Consume link_data parameter - self.__link_data = popargs('link_data', kwargs) + self.__link_data = popargs("link_data", kwargs) # Consume allow_plugin_filters parameter - self.__allow_plugin_filters = popargs('allow_plugin_filters', kwargs) + self.__allow_plugin_filters = popargs("allow_plugin_filters", kwargs) # Check for possible collision with other parameters - if not isinstance(getargs('data', kwargs), Dataset) and self.__link_data: + if not isinstance(getargs("data", kwargs), Dataset) and self.__link_data: self.__link_data = False - warnings.warn('link_data parameter in H5DataIO will be ignored') + warnings.warn("link_data parameter in H5DataIO will be ignored") # Call the super constructor and consume the data parameter super().__init__(**kwargs) # Construct the dict with the io args, ignoring all options that were set to None self.__iosettings = {k: v for k, v in zip(ioarg_names, ioarg_values) if v is not None} if self.data is None: - self.__iosettings['dtype'] = self.dtype - self.__iosettings['shape'] = self.shape + self.__iosettings["dtype"] = self.dtype + self.__iosettings["shape"] = self.shape # Set io_properties for DataChunkIterators if isinstance(self.data, AbstractDataChunkIterator): # Define the chunking options if the user has not set them explicitly. - if 'chunks' not in self.__iosettings and self.data.recommended_chunk_shape() is not None: - self.__iosettings['chunks'] = self.data.recommended_chunk_shape() + if "chunks" not in self.__iosettings and self.data.recommended_chunk_shape() is not None: + self.__iosettings["chunks"] = self.data.recommended_chunk_shape() # Define the maxshape of the data if not provided by the user - if 'maxshape' not in self.__iosettings: - self.__iosettings['maxshape'] = self.data.maxshape + if "maxshape" not in self.__iosettings: + self.__iosettings["maxshape"] = self.data.maxshape # Make default settings when compression set to bool (True/False) - if isinstance(self.__iosettings.get('compression', None), bool): - if self.__iosettings['compression']: - self.__iosettings['compression'] = 'gzip' + if isinstance(self.__iosettings.get("compression", None), bool): + if self.__iosettings["compression"]: + self.__iosettings["compression"] = "gzip" else: - self.__iosettings.pop('compression', None) - if 'compression_opts' in self.__iosettings: - warnings.warn('Compression disabled by compression=False setting. 
' + - 'compression_opts parameter will, therefore, be ignored.') - self.__iosettings.pop('compression_opts', None) + self.__iosettings.pop("compression", None) + if "compression_opts" in self.__iosettings: + warnings.warn( + "Compression disabled by compression=False setting. " + "compression_opts parameter will, therefore, be ignored." + ) + self.__iosettings.pop("compression_opts", None) # Validate the compression options used self._check_compression_options() # Confirm that the compressor is supported by h5py - if not self.filter_available(self.__iosettings.get('compression', None), - self.__allow_plugin_filters): - msg = "%s compression may not be supported by this version of h5py." % str(self.__iosettings['compression']) + if not self.filter_available( + self.__iosettings.get("compression", None), + self.__allow_plugin_filters, + ): + msg = "%s compression may not be supported by this version of h5py." % str(self.__iosettings["compression"]) if not self.__allow_plugin_filters: msg += " Set `allow_plugin_filters=True` to enable the use of dynamically-loaded plugin filters." raise ValueError(msg) @@ -556,7 +617,7 @@ def get_io_params(self): Returns a dict with the I/O parameters specified in this DataIO. """ ret = dict(self.__iosettings) - ret['link_data'] = self.__link_data + ret["link_data"] = self.__link_data return ret def _check_compression_options(self): @@ -566,35 +627,43 @@ def _check_compression_options(self): :raises ValueError: If incompatible options are detected """ - if 'compression' in self.__iosettings: - if 'compression_opts' in self.__iosettings: - if self.__iosettings['compression'] == 'gzip': - if self.__iosettings['compression_opts'] not in range(10): - raise ValueError("GZIP compression_opts setting must be an integer from 0-9, " - "not " + str(self.__iosettings['compression_opts'])) - elif self.__iosettings['compression'] == 'lzf': - if self.__iosettings['compression_opts'] is not None: + if "compression" in self.__iosettings: + if "compression_opts" in self.__iosettings: + if self.__iosettings["compression"] == "gzip": + if self.__iosettings["compression_opts"] not in range(10): + raise ValueError( + "GZIP compression_opts setting must be an integer from 0-9, not " + + str(self.__iosettings["compression_opts"]) + ) + elif self.__iosettings["compression"] == "lzf": + if self.__iosettings["compression_opts"] is not None: raise ValueError("LZF compression filter accepts no compression_opts") - elif self.__iosettings['compression'] == 'szip': + elif self.__iosettings["compression"] == "szip": szip_opts_error = False # Check that we have a tuple - szip_opts_error |= not isinstance(self.__iosettings['compression_opts'], tuple) + szip_opts_error |= not isinstance(self.__iosettings["compression_opts"], tuple) # Check that we have a tuple of the right length and correct settings if not szip_opts_error: try: - szmethod, szpix = self.__iosettings['compression_opts'] - szip_opts_error |= (szmethod not in ('ec', 'nn')) - szip_opts_error |= (not (0 < szpix <= 32 and szpix % 2 == 0)) + szmethod, szpix = self.__iosettings["compression_opts"] + szip_opts_error |= szmethod not in ("ec", "nn") + szip_opts_error |= not (0 < szpix <= 32 and szpix % 2 == 0) except ValueError: # ValueError is raised if tuple does not have the right length to unpack szip_opts_error = True if szip_opts_error: - raise ValueError("SZIP compression filter compression_opts" - " must be a 2-tuple ('ec'|'nn', even integer 0-32).") + raise ValueError( + "SZIP compression filter compression_opts must be a 
2-tuple ('ec'|'nn', even integer 0-32)." + ) # Warn if compressor other than gzip is being used - if self.__iosettings['compression'] not in ['gzip', h5py_filters.h5z.FILTER_DEFLATE]: - warnings.warn(str(self.__iosettings['compression']) + " compression may not be available " - "on all installations of HDF5. Use of gzip is recommended to ensure portability of " - "the generated HDF5 files.") + if self.__iosettings["compression"] not in [ + "gzip", + h5py_filters.h5z.FILTER_DEFLATE, + ]: + warnings.warn( + str(self.__iosettings["compression"]) + + " compression may not be available on all installations of HDF5." + " Use of gzip is recommended to ensure portability of the generated HDF5 files." + ) @staticmethod def filter_available(filter, allow_plugin_filters): @@ -614,8 +683,10 @@ def filter_available(filter, allow_plugin_filters): if type(filter) == int: if h5py_filters.h5z.filter_avail(filter): filter_info = h5py_filters.h5z.get_filter_info(filter) - if filter_info == (h5py_filters.h5z.FILTER_CONFIG_DECODE_ENABLED + - h5py_filters.h5z.FILTER_CONFIG_ENCODE_ENABLED): + if filter_info == ( + h5py_filters.h5z.FILTER_CONFIG_DECODE_ENABLED + + h5py_filters.h5z.FILTER_CONFIG_ENCODE_ENABLED + ): return True return False else: diff --git a/src/hdmf/backends/hdf5/h5tools.py b/src/hdmf/backends/hdf5/h5tools.py index 7767d234a..3b2111769 100644 --- a/src/hdmf/backends/hdf5/h5tools.py +++ b/src/hdmf/backends/hdf5/h5tools.py @@ -3,68 +3,123 @@ import warnings from collections import deque from functools import partial -from pathlib import Path, PurePosixPath as pp +from pathlib import Path +from pathlib import PurePosixPath as pp -import numpy as np import h5py -from h5py import File, Group, Dataset, special_dtype, SoftLink, ExternalLink, Reference, RegionReference, check_dtype - -from .h5_utils import (BuilderH5ReferenceDataset, BuilderH5RegionDataset, BuilderH5TableDataset, H5DataIO, - H5SpecReader, H5SpecWriter, HDF5IODataChunkIteratorQueue) -from ..io import HDMFIO -from ..errors import UnsupportedOperation -from ..warnings import BrokenLinkWarning -from ...build import (Builder, GroupBuilder, DatasetBuilder, LinkBuilder, BuildManager, RegionBuilder, - ReferenceBuilder, TypeMap, ObjectMapper) +import numpy as np +from h5py import ( + Dataset, + ExternalLink, + File, + Group, + Reference, + RegionReference, + SoftLink, + check_dtype, + special_dtype, +) + +from ...build import ( + Builder, + BuildManager, + DatasetBuilder, + GroupBuilder, + LinkBuilder, + ObjectMapper, + ReferenceBuilder, + RegionBuilder, + TypeMap, +) from ...container import Container from ...data_utils import AbstractDataChunkIterator -from ...spec import RefSpec, DtypeSpec, NamespaceCatalog -from ...utils import docval, getargs, popargs, get_data_shape, get_docval, StrDataset +from ...spec import DtypeSpec, NamespaceCatalog, RefSpec +from ...utils import StrDataset, docval, get_data_shape, get_docval, getargs, popargs +from ..errors import UnsupportedOperation +from ..io import HDMFIO from ..utils import NamespaceToBuilderHelper, WriteStatusTracker - -ROOT_NAME = 'root' -SPEC_LOC_ATTR = '.specloc' +from ..warnings import BrokenLinkWarning +from .h5_utils import ( + BuilderH5ReferenceDataset, + BuilderH5RegionDataset, + BuilderH5TableDataset, + H5DataIO, + H5SpecReader, + H5SpecWriter, + HDF5IODataChunkIteratorQueue, +) + +ROOT_NAME = "root" +SPEC_LOC_ATTR = ".specloc" H5_TEXT = special_dtype(vlen=str) H5_BINARY = special_dtype(vlen=bytes) H5_REF = special_dtype(ref=Reference) H5_REGREF = special_dtype(ref=RegionReference) 
-H5PY_3 = h5py.__version__.startswith('3') +H5PY_3 = h5py.__version__.startswith("3") class HDF5IO(HDMFIO): - - __ns_spec_path = 'namespace' # path to the namespace dataset within a namespace group - - @docval({'name': 'path', 'type': (str, Path), 'doc': 'the path to the HDF5 file', 'default': None}, - {'name': 'mode', 'type': str, - 'doc': ('the mode to open the HDF5 file with, one of ("w", "r", "r+", "a", "w-", "x"). ' - 'See `h5py.File `_ for ' - 'more details.'), - 'default': 'r'}, - {'name': 'manager', 'type': (TypeMap, BuildManager), - 'doc': 'the BuildManager or a TypeMap to construct a BuildManager to use for I/O', 'default': None}, - {'name': 'comm', 'type': 'Intracomm', - 'doc': 'the MPI communicator to use for parallel I/O', 'default': None}, - {'name': 'file', 'type': [File, "S3File"], 'doc': 'a pre-existing h5py.File object', 'default': None}, - {'name': 'driver', 'type': str, 'doc': 'driver for h5py to use when opening HDF5 file', 'default': None}) + __ns_spec_path = "namespace" # path to the namespace dataset within a namespace group + + @docval( + { + "name": "path", + "type": (str, Path), + "doc": "the path to the HDF5 file", + "default": None, + }, + { + "name": "mode", + "type": str, + "doc": ( + 'the mode to open the HDF5 file with, one of ("w", "r", "r+", "a", "w-", "x"). See `h5py.File' + " `_ for more details." + ), + "default": "r", + }, + { + "name": "manager", + "type": (TypeMap, BuildManager), + "doc": "the BuildManager or a TypeMap to construct a BuildManager to use for I/O", + "default": None, + }, + { + "name": "comm", + "type": "Intracomm", + "doc": "the MPI communicator to use for parallel I/O", + "default": None, + }, + { + "name": "file", + "type": [File, "S3File"], + "doc": "a pre-existing h5py.File object", + "default": None, + }, + { + "name": "driver", + "type": str, + "doc": "driver for h5py to use when opening HDF5 file", + "default": None, + }, + ) def __init__(self, **kwargs): - """Open an HDF5 file for IO. - """ - self.logger = logging.getLogger('%s.%s' % (self.__class__.__module__, self.__class__.__qualname__)) - path, manager, mode, comm, file_obj, driver = popargs('path', 'manager', 'mode', 'comm', 'file', 'driver', - kwargs) + """Open an HDF5 file for IO.""" + self.logger = logging.getLogger("%s.%s" % (self.__class__.__module__, self.__class__.__qualname__)) + path, manager, mode, comm, file_obj, driver = popargs( + "path", "manager", "mode", "comm", "file", "driver", kwargs + ) self.__open_links = [] # keep track of other files opened from links in this file self.__file = None # This will be set below, but set to None first in case an error occurs and we need to close path = self.__check_path_file_obj(path, file_obj) - if file_obj is None and not os.path.exists(path) and (mode == 'r' or mode == 'r+') and driver != 'ros3': + if file_obj is None and not os.path.exists(path) and (mode == "r" or mode == "r+") and driver != "ros3": msg = "Unable to open file %s in '%s' mode. File does not exist." % (path, mode) raise UnsupportedOperation(msg) - if file_obj is None and os.path.exists(path) and (mode == 'w-' or mode == 'x'): + if file_obj is None and os.path.exists(path) and (mode == "w-" or mode == "x"): msg = "Unable to open file %s in '%s' mode. File already exists." 
% (path, mode) raise UnsupportedOperation(msg) @@ -77,8 +132,8 @@ def __init__(self, **kwargs): self.__mode = mode self.__file = file_obj super().__init__(manager, source=path) # NOTE: source is not set if path is None and file_obj is passed - self.__built = dict() # keep track of each builder for each dataset/group/link for each file - self.__read = dict() # keep track of which files have been read. Key is the filename value is the builder + self.__built = dict() # keep track of each builder for each dataset/group/link for each file + self.__read = dict() # keep track of which files have been read. Key is the filename value is the builder self.__ref_queue = deque() # a queue of the references that need to be added self.__dci_queue = HDF5IODataChunkIteratorQueue() # a queue of DataChunkIterators that need to be exhausted ObjectMapper.no_convert(Dataset) @@ -107,8 +162,10 @@ def __check_path_file_obj(cls, path, file_obj): if path is not None and file_obj is not None: # consistency check if os.path.abspath(file_obj.filename) != os.path.abspath(path): - msg = ("You argued '%s' as this object's path, but supplied a file with filename: %s" - % (path, file_obj.filename)) + msg = "You argued '%s' as this object's path, but supplied a file with filename: %s" % ( + path, + file_obj.filename, + ) raise ValueError(msg) return path @@ -121,19 +178,46 @@ def __resolve_file_obj(cls, path, file_obj, driver): file_kwargs = dict() if driver is not None: file_kwargs.update(driver=driver) - file_obj = File(path, 'r', **file_kwargs) + file_obj = File(path, "r", **file_kwargs) return file_obj @classmethod - @docval({'name': 'namespace_catalog', 'type': (NamespaceCatalog, TypeMap), - 'doc': 'the NamespaceCatalog or TypeMap to load namespaces into'}, - {'name': 'path', 'type': (str, Path), 'doc': 'the path to the HDF5 file', 'default': None}, - {'name': 'namespaces', 'type': list, 'doc': 'the namespaces to load', 'default': None}, - {'name': 'file', 'type': File, 'doc': 'a pre-existing h5py.File object', 'default': None}, - {'name': 'driver', 'type': str, 'doc': 'driver for h5py to use when opening HDF5 file', 'default': None}, - returns=("dict mapping the names of the loaded namespaces to a dict mapping included namespace names and " - "the included data types"), - rtype=dict) + @docval( + { + "name": "namespace_catalog", + "type": (NamespaceCatalog, TypeMap), + "doc": "the NamespaceCatalog or TypeMap to load namespaces into", + }, + { + "name": "path", + "type": (str, Path), + "doc": "the path to the HDF5 file", + "default": None, + }, + { + "name": "namespaces", + "type": list, + "doc": "the namespaces to load", + "default": None, + }, + { + "name": "file", + "type": File, + "doc": "a pre-existing h5py.File object", + "default": None, + }, + { + "name": "driver", + "type": str, + "doc": "driver for h5py to use when opening HDF5 file", + "default": None, + }, + returns=( + "dict mapping the names of the loaded namespaces to a dict mapping included" + " namespace names and the included data types" + ), + rtype=dict, + ) def load_namespaces(cls, **kwargs): """Load cached namespaces from a file. @@ -144,7 +228,8 @@ def load_namespaces(cls, **kwargs): :raises ValueError: if both `path` and `file` are supplied but `path` is not the same as the path of `file`. 
""" namespace_catalog, path, namespaces, file_obj, driver = popargs( - 'namespace_catalog', 'path', 'namespaces', 'file', 'driver', kwargs) + "namespace_catalog", "path", "namespaces", "file", "driver", kwargs + ) open_file_obj = cls.__resolve_file_obj(path, file_obj, driver) if file_obj is None: # need to close the file object that we just opened @@ -175,8 +260,8 @@ def __load_namespaces(cls, namespace_catalog, namespaces, file_obj): # for each namespace in the 'namespace' dataset, track all included namespaces (dependencies) for spec_ns in reader.read_namespace(cls.__ns_spec_path): deps[ns] = list() - for s in spec_ns['schema']: - dep = s.get('namespace') + for s in spec_ns["schema"]: + dep = s.get("namespace") if dep is not None: deps[ns].append(dep) @@ -197,10 +282,13 @@ def __check_specloc(cls, file_obj): return True @classmethod - @docval({'name': 'path', 'type': (str, Path), 'doc': 'the path to the HDF5 file', 'default': None}, - {'name': 'file', 'type': File, 'doc': 'a pre-existing h5py.File object', 'default': None}, - {'name': 'driver', 'type': str, 'doc': 'driver for h5py to use when opening HDF5 file', 'default': None}, - returns="dict mapping names to versions of the namespaces in the file", rtype=dict) + @docval( + {"name": "path", "type": (str, Path), "doc": "the path to the HDF5 file", "default": None}, + {"name": "file", "type": File, "doc": "a pre-existing h5py.File object", "default": None}, + {"name": "driver", "type": str, "doc": "driver for h5py to use when opening HDF5 file", "default": None}, + returns="dict mapping names to versions of the namespaces in the file", + rtype=dict, + ) def get_namespaces(cls, **kwargs): """Get the names and versions of the cached namespaces from a file. @@ -213,7 +301,7 @@ def get_namespaces(cls, **kwargs): :raises ValueError: if both `path` and `file` are supplied but `path` is not the same as the path of `file`. 
""" - path, file_obj, driver = popargs('path', 'file', 'driver', kwargs) + path, file_obj, driver = popargs("path", "file", "driver", kwargs) open_file_obj = cls.__resolve_file_obj(path, file_obj, driver) if file_obj is None: # need to close the file object that we just opened @@ -241,13 +329,13 @@ def __get_namespaces(cls, file_obj): if len(version_names) > 1: # prior to HDMF 1.6.1, extensions without a version were written under the group name "unversioned" # make sure that if there is another group representing a newer version, that is read instead - if 'unversioned' in version_names: - version_names.remove('unversioned') + if "unversioned" in version_names: + version_names.remove("unversioned") if len(version_names) > 1: # as of HDMF 1.6.1, extensions without a version are written under the group name "None" # make sure that if there is another group representing a newer version, that is read instead - if 'None' in version_names: - version_names.remove('None') + if "None" in version_names: + version_names.remove("None") used_version_names[ns] = version_names[-1] # save the largest in alphanumeric order return used_version_names @@ -284,13 +372,36 @@ def __order_deps_aux(cls, order, deps, key): order.append(key) @classmethod - @docval({'name': 'source_filename', 'type': str, 'doc': 'the path to the HDF5 file to copy'}, - {'name': 'dest_filename', 'type': str, 'doc': 'the name of the destination file'}, - {'name': 'expand_external', 'type': bool, 'doc': 'expand external links into new objects', 'default': True}, - {'name': 'expand_refs', 'type': bool, 'doc': 'copy objects which are pointed to by reference', - 'default': False}, - {'name': 'expand_soft', 'type': bool, 'doc': 'expand soft links into new objects', 'default': False} - ) + @docval( + { + "name": "source_filename", + "type": str, + "doc": "the path to the HDF5 file to copy", + }, + { + "name": "dest_filename", + "type": str, + "doc": "the name of the destination file", + }, + { + "name": "expand_external", + "type": bool, + "doc": "expand external links into new objects", + "default": True, + }, + { + "name": "expand_refs", + "type": bool, + "doc": "copy objects which are pointed to by reference", + "default": False, + }, + { + "name": "expand_soft", + "type": bool, + "doc": "expand soft links into new objects", + "default": False, + }, + ) def copy_file(self, **kwargs): """ Convenience function to copy an HDF5 file while allowing external links to be resolved. @@ -308,52 +419,70 @@ def copy_file(self, **kwargs): """ - warnings.warn("The copy_file class method is no longer supported and may be removed in a future version of " - "HDMF. Please use the export method or h5py.File.copy method instead.", DeprecationWarning) - - source_filename, dest_filename, expand_external, expand_refs, expand_soft = getargs('source_filename', - 'dest_filename', - 'expand_external', - 'expand_refs', - 'expand_soft', - kwargs) - source_file = File(source_filename, 'r') - dest_file = File(dest_filename, 'w') + warnings.warn( + ( + "The copy_file class method is no longer supported and may be removed " + "in a future version of HDMF. Please use the export method or h5py.File.copy method instead." 
+ ), + DeprecationWarning, + ) + + source_file = File(kwargs["source_filename"], "r") + dest_file = File(kwargs["dest_filename"], "w") for objname in source_file["/"].keys(): - source_file.copy(source=objname, - dest=dest_file, - name=objname, - expand_external=expand_external, - expand_refs=expand_refs, - expand_soft=expand_soft, - shallow=False, - without_attrs=False, - ) - for objname in source_file['/'].attrs: - dest_file['/'].attrs[objname] = source_file['/'].attrs[objname] + source_file.copy( + source=objname, + dest=dest_file, + name=objname, + expand_external=kwargs["expand_external"], + expand_refs=kwargs["expand_refs"], + expand_soft=kwargs["expand_soft"], + shallow=False, + without_attrs=False, + ) + for objname in source_file["/"].attrs: + dest_file["/"].attrs[objname] = source_file["/"].attrs[objname] source_file.close() dest_file.close() - @docval({'name': 'container', 'type': Container, 'doc': 'the Container object to write'}, - {'name': 'cache_spec', 'type': bool, - 'doc': ('If True (default), cache specification to file (highly recommended). If False, do not cache ' - 'specification to file. The appropriate specification will then need to be loaded prior to ' - 'reading the file.'), - 'default': True}, - {'name': 'link_data', 'type': bool, - 'doc': 'If True (default), create external links to HDF5 Datasets. If False, copy HDF5 Datasets.', - 'default': True}, - {'name': 'exhaust_dci', 'type': bool, - 'doc': 'If True (default), exhaust DataChunkIterators one at a time. If False, exhaust them concurrently.', - 'default': True}) + @docval( + { + "name": "container", + "type": Container, + "doc": "the Container object to write", + }, + { + "name": "cache_spec", + "type": bool, + "doc": ( + "If True (default), cache specification to file (highly recommended). " + "If False, do not cache specification to file. The appropriate " + "specification will then need to be loaded prior to reading the file." + ), + "default": True, + }, + { + "name": "link_data", + "type": bool, + "doc": "If True (default), create external links to HDF5 Datasets. If False, copy HDF5 Datasets.", + "default": True, + }, + { + "name": "exhaust_dci", + "type": bool, + "doc": "If True (default), exhaust DataChunkIterators one at a time. If False, exhaust them concurrently.", + "default": True, + }, + ) def write(self, **kwargs): """Write the container to an HDF5 file.""" - if self.__mode == 'r': - raise UnsupportedOperation(("Cannot write to file %s in mode '%s'. " - "Please use mode 'r+', 'w', 'w-', 'x', or 'a'") - % (self.source, self.__mode)) + if self.__mode == "r": + raise UnsupportedOperation( + "Cannot write to file %s in mode '%s'. 
Please use mode 'r+', 'w', 'w-', 'x', or 'a'" + % (self.source, self.__mode) + ) - cache_spec = popargs('cache_spec', kwargs) + cache_spec = popargs("cache_spec", kwargs) super().write(**kwargs) if cache_spec: self.__cache_spec() @@ -364,14 +493,14 @@ def __cache_spec(self): if ref is not None: spec_group = self.__file[ref] else: - path = 'specifications' # do something to figure out where the specifications should go + path = "specifications" # do something to figure out where the specifications should go spec_group = self.__file.require_group(path) self.__file.attrs[SPEC_LOC_ATTR] = spec_group.ref ns_catalog = self.manager.namespace_catalog for ns_name in ns_catalog.namespaces: ns_builder = NamespaceToBuilderHelper.convert_namespace(ns_catalog, ns_name) namespace = ns_catalog.get_namespace(ns_name) - group_name = '%s/%s' % (ns_name, namespace.version) + group_name = "%s/%s" % (ns_name, namespace.version) if group_name in spec_group: continue ns_group = spec_group.create_group(group_name) @@ -379,15 +508,32 @@ def __cache_spec(self): ns_builder.export(self.__ns_spec_path, writer=writer) _export_args = ( - {'name': 'src_io', 'type': 'HDMFIO', 'doc': 'the HDMFIO object for reading the data to export'}, - {'name': 'container', 'type': Container, - 'doc': ('the Container object to export. If None, then the entire contents of the HDMFIO object will be ' - 'exported'), - 'default': None}, - {'name': 'write_args', 'type': dict, 'doc': 'arguments to pass to :py:meth:`write_builder`', - 'default': None}, - {'name': 'cache_spec', 'type': bool, 'doc': 'whether to cache the specification to file', - 'default': True} + { + "name": "src_io", + "type": "HDMFIO", + "doc": "the HDMFIO object for reading the data to export", + }, + { + "name": "container", + "type": Container, + "doc": ( + "the Container object to export. If None, then the entire contents of the HDMFIO object will be " + "exported" + ), + "default": None, + }, + { + "name": "write_args", + "type": dict, + "doc": "arguments to pass to :py:meth:`write_builder`", + "default": None, + }, + { + "name": "cache_spec", + "type": bool, + "doc": "whether to cache the specification to file", + "default": True, + }, # clear_cache is an arg on HDMFIO.export but it is intended for internal usage # so it is not available on HDF5IO ) @@ -398,40 +544,43 @@ def export(self, **kwargs): See :py:meth:`hdmf.backends.io.HDMFIO.export` for more details. """ - if self.__mode != 'w': - raise UnsupportedOperation("Cannot export to file %s in mode '%s'. Please use mode 'w'." - % (self.source, self.__mode)) + if self.__mode != "w": + raise UnsupportedOperation( + "Cannot export to file %s in mode '%s'. Please use mode 'w'." % (self.source, self.__mode) + ) - src_io = getargs('src_io', kwargs) - write_args, cache_spec = popargs('write_args', 'cache_spec', kwargs) + src_io = getargs("src_io", kwargs) + write_args, cache_spec = popargs("write_args", "cache_spec", kwargs) if write_args is None: write_args = dict() - if not isinstance(src_io, HDF5IO) and write_args.get('link_data', True): - raise UnsupportedOperation("Cannot export from non-HDF5 backend %s to HDF5 with write argument " - "link_data=True." % src_io.__class__.__name__) + if not isinstance(src_io, HDF5IO) and write_args.get("link_data", True): + raise UnsupportedOperation( + "Cannot export from non-HDF5 backend %s to HDF5 with write argument link_data=True." 
+ % src_io.__class__.__name__ + ) - write_args['export_source'] = os.path.abspath(src_io.source) if src_io.source is not None else None + write_args["export_source"] = os.path.abspath(src_io.source) if src_io.source is not None else None ckwargs = kwargs.copy() - ckwargs['write_args'] = write_args - if not write_args.get('link_data', True): - ckwargs['clear_cache'] = True + ckwargs["write_args"] = write_args + if not write_args.get("link_data", True): + ckwargs["clear_cache"] = True super().export(**ckwargs) if cache_spec: # add any namespaces from the src_io that have not yet been loaded for namespace in src_io.manager.namespace_catalog.namespaces: if namespace not in self.manager.namespace_catalog.namespaces: self.manager.namespace_catalog.add_namespace( - name=namespace, - namespace=src_io.manager.namespace_catalog.get_namespace(namespace) + name=namespace, namespace=src_io.manager.namespace_catalog.get_namespace(namespace) ) self.__cache_spec() @classmethod - @docval({'name': 'path', 'type': str, 'doc': 'the path to the destination HDF5 file'}, - {'name': 'comm', 'type': 'Intracomm', 'doc': 'the MPI communicator to use for parallel I/O', - 'default': None}, - *_export_args) # NOTE: src_io is required and is the second positional argument + @docval( + {"name": "path", "type": str, "doc": "the path to the destination HDF5 file"}, + {"name": "comm", "type": "Intracomm", "doc": "the MPI communicator to use for parallel I/O", "default": None}, + *_export_args, + ) # NOTE: src_io is required and is the second positional argument def export_io(self, **kwargs): """Export from one backend to HDF5 (class method). @@ -448,23 +597,25 @@ def export_io(self, **kwargs): See :py:meth:`export` for more details. """ - path, comm = popargs('path', 'comm', kwargs) + path, comm = popargs("path", "comm", kwargs) - with HDF5IO(path=path, comm=comm, mode='w') as write_io: + with HDF5IO(path=path, comm=comm, mode="w") as write_io: write_io.export(**kwargs) def read(self, **kwargs): - if self.__mode == 'w' or self.__mode == 'w-' or self.__mode == 'x': - raise UnsupportedOperation("Cannot read from file %s in mode '%s'. Please use mode 'r', 'r+', or 'a'." - % (self.source, self.__mode)) + if self.__mode == "w" or self.__mode == "w-" or self.__mode == "x": + raise UnsupportedOperation( + "Cannot read from file %s in mode '%s'. Please use mode 'r', 'r+', or 'a'." % (self.source, self.__mode) + ) try: return super().read(**kwargs) except UnsupportedOperation as e: - if str(e) == 'Cannot build data. There are no values.': # pragma: no cover - raise UnsupportedOperation("Cannot read data from file %s in mode '%s'. There are no values." - % (self.source, self.__mode)) + if str(e) == "Cannot build data. There are no values.": # pragma: no cover + raise UnsupportedOperation( + "Cannot read data from file %s in mode '%s'. There are no values." % (self.source, self.__mode) + ) - @docval(returns='a GroupBuilder representing the data object', rtype='GroupBuilder') + @docval(returns="a GroupBuilder representing the data object", rtype="GroupBuilder") def read_builder(self): """ Read data and return the GroupBuilder representing it. 
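For orientation, the `H5DataIO` and `HDF5IO` signatures reformatted above are easier to read alongside a usage sketch. The snippet below is illustrative only and is not part of this diff: the file name, table name, and column names are invented, and it simply combines `H5DataIO` (the per-dataset options validated in `_check_compression_options`) with `HDF5IO` (file-level write and read, as in the methods above).

```python
import numpy as np

from hdmf.backends.hdf5 import H5DataIO, HDF5IO
from hdmf.common import DynamicTable, VectorData, get_manager

# Per-dataset I/O options: gzip level 4, chunked, first axis resizable.
wrapped = H5DataIO(
    data=np.arange(1000),
    compression="gzip",  # a bool True also selects gzip (see __init__ above)
    compression_opts=4,  # gzip accepts integers 0-9 (see _check_compression_options)
    chunks=(100,),
    maxshape=(None,),  # None marks an axis as unlimited
)

table = DynamicTable(
    name="example_table",  # invented name, for illustration only
    description="an illustrative table",
    columns=[VectorData(name="values", description="a compressed column", data=wrapped)],
)

# File-level write and read; the mode handling mirrors the checks in write() and read() above.
with HDF5IO("example.h5", manager=get_manager(), mode="w") as io:
    io.write(table)

with HDF5IO("example.h5", manager=get_manager(), mode="r") as io:
    read_table = io.read()
```

Wrapping only the column data in `H5DataIO` keeps storage tuning separate from the container hierarchy, which is why the writer treats it as a thin pass-through of keyword settings.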
@@ -535,31 +686,31 @@ def __get_built(self, fpath, id): else: return None - @docval({'name': 'h5obj', 'type': (Dataset, Group), - 'doc': 'the HDF5 object to the corresponding Builder object for'}) + @docval({"name": "h5obj", "type": (Dataset, Group), "doc": "the HDF5 object to the corresponding Builder object"}) def get_builder(self, **kwargs): """ Get the builder for the corresponding h5py Group or Dataset :raises ValueError: When no builder has been constructed yet for the given h5py object """ - h5obj = getargs('h5obj', kwargs) + h5obj = getargs("h5obj", kwargs) fpath = h5obj.file.filename builder = self.__get_built(fpath, h5obj.id) if builder is None: - msg = '%s:%s has not been built' % (fpath, h5obj.name) + msg = "%s:%s has not been built" % (fpath, h5obj.name) raise ValueError(msg) return builder - @docval({'name': 'h5obj', 'type': (Dataset, Group), - 'doc': 'the HDF5 object to the corresponding Container/Data object for'}) + @docval( + {"name": "h5obj", "type": (Dataset, Group), "doc": "the HDF5 object to the corresponding Container/Data object"} + ) def get_container(self, **kwargs): """ Get the container for the corresponding h5py Group or Dataset :raises ValueError: When no builder has been constructed yet for the given h5py object """ - h5obj = getargs('h5obj', kwargs) + h5obj = getargs("h5obj", kwargs) builder = self.get_builder(h5obj) container = self.manager.construct(builder) return container @@ -569,12 +720,12 @@ def __read_group(self, h5obj, name=None, ignore=set()): "attributes": self.__read_attrs(h5obj), "groups": dict(), "datasets": dict(), - "links": dict() + "links": dict(), } - for key, val in kwargs['attributes'].items(): + for key, val in kwargs["attributes"].items(): if isinstance(val, bytes): - kwargs['attributes'][key] = val.decode('UTF-8') + kwargs["attributes"][key] = val.decode("UTF-8") if name is None: name = str(os.path.basename(h5obj.name)) @@ -598,11 +749,11 @@ def __read_group(self, h5obj, name=None, ignore=set()): builder = self.__read_dataset(target_obj, builder_name) else: builder = self.__read_group(target_obj, builder_name, ignore=ignore) - self.__set_built(sub_h5obj.file.filename, target_obj.id, builder) + self.__set_built(sub_h5obj.file.filename, target_obj.id, builder) link_builder = LinkBuilder(builder=builder, name=k, source=os.path.abspath(h5obj.file.filename)) link_builder.location = h5obj.name self.__set_written(link_builder) - kwargs['links'][builder_name] = link_builder + kwargs["links"][builder_name] = link_builder if isinstance(link_type, ExternalLink): self.__open_links.append(sub_h5obj) else: @@ -611,19 +762,22 @@ def __read_group(self, h5obj, name=None, ignore=set()): read_method = None if isinstance(sub_h5obj, Dataset): read_method = self.__read_dataset - obj_type = kwargs['datasets'] + obj_type = kwargs["datasets"] else: read_method = partial(self.__read_group, ignore=ignore) - obj_type = kwargs['groups'] + obj_type = kwargs["groups"] if builder is None: builder = read_method(sub_h5obj) self.__set_built(sub_h5obj.file.filename, sub_h5obj.id, builder) obj_type[builder.name] = builder else: - warnings.warn('Path to Group altered/broken at ' + os.path.join(h5obj.name, k), BrokenLinkWarning) - kwargs['datasets'][k] = None + warnings.warn( + "Path to Group altered/broken at " + os.path.join(h5obj.name, k), + BrokenLinkWarning, + ) + kwargs["datasets"][k] = None continue - kwargs['source'] = os.path.abspath(h5obj.file.filename) + kwargs["source"] = os.path.abspath(h5obj.file.filename) ret = GroupBuilder(name, **kwargs) ret.location = 
os.path.dirname(h5obj.name) self.__set_written(ret) @@ -633,20 +787,20 @@ def __read_dataset(self, h5obj, name=None): kwargs = { "attributes": self.__read_attrs(h5obj), "dtype": h5obj.dtype, - "maxshape": h5obj.maxshape + "maxshape": h5obj.maxshape, } - for key, val in kwargs['attributes'].items(): + for key, val in kwargs["attributes"].items(): if isinstance(val, bytes): - kwargs['attributes'][key] = val.decode('UTF-8') + kwargs["attributes"][key] = val.decode("UTF-8") if name is None: name = str(os.path.basename(h5obj.name)) - kwargs['source'] = os.path.abspath(h5obj.file.filename) + kwargs["source"] = os.path.abspath(h5obj.file.filename) ndims = len(h5obj.shape) - if ndims == 0: # read scalar + if ndims == 0: # read scalar scalar = h5obj[()] if isinstance(scalar, bytes): - scalar = scalar.decode('UTF-8') + scalar = scalar.decode("UTF-8") if isinstance(scalar, Reference): # TODO (AJTRITT): This should call __read_ref to support Group references @@ -657,27 +811,27 @@ def __read_dataset(self, h5obj, name=None): d = RegionBuilder(scalar, target_builder) else: d = ReferenceBuilder(target_builder) - kwargs['data'] = d - kwargs['dtype'] = d.dtype + kwargs["data"] = d + kwargs["dtype"] = d.dtype else: kwargs["data"] = scalar else: d = None - if h5obj.dtype.kind == 'O' and len(h5obj) > 0: + if h5obj.dtype.kind == "O" and len(h5obj) > 0: elem1 = h5obj[tuple([0] * (h5obj.ndim - 1) + [0])] if isinstance(elem1, (str, bytes)): d = self._check_str_dtype(h5obj) elif isinstance(elem1, RegionReference): # read list of references d = BuilderH5RegionDataset(h5obj, self) - kwargs['dtype'] = d.dtype + kwargs["dtype"] = d.dtype elif isinstance(elem1, Reference): d = BuilderH5ReferenceDataset(h5obj, self) - kwargs['dtype'] = d.dtype - elif h5obj.dtype.kind == 'V': # table / compound data type + kwargs["dtype"] = d.dtype + elif h5obj.dtype.kind == "V": # table / compound data type cpd_dt = h5obj.dtype ref_cols = [check_dtype(ref=cpd_dt[i]) or check_dtype(vlen=cpd_dt[i]) for i in range(len(cpd_dt))] d = BuilderH5TableDataset(h5obj, self, ref_cols) - kwargs['dtype'] = HDF5IO.__compound_dtype_to_list(h5obj.dtype, d.dtype) + kwargs["dtype"] = HDF5IO.__compound_dtype_to_list(h5obj.dtype, d.dtype) else: d = h5obj kwargs["data"] = d @@ -688,8 +842,8 @@ def __read_dataset(self, h5obj, name=None): def _check_str_dtype(self, h5obj): dtype = h5obj.dtype - if dtype.kind == 'O': - if dtype.metadata.get('vlen') == str and H5PY_3: + if dtype.kind == "O": + if dtype.metadata.get("vlen") == str and H5PY_3: return StrDataset(h5obj, None) return h5obj @@ -697,13 +851,13 @@ def _check_str_dtype(self, h5obj): def __compound_dtype_to_list(cls, h5obj_dtype, dset_dtype): ret = [] for name, dtype in zip(h5obj_dtype.fields, dset_dtype): - ret.append({'name': name, 'dtype': dtype}) + ret.append({"name": name, "dtype": dtype}) return ret def __read_attrs(self, h5obj): ret = dict() for k, v in h5obj.attrs.items(): - if k == SPEC_LOC_ATTR: # ignore cached spec + if k == SPEC_LOC_ATTR: # ignore cached spec continue if isinstance(v, RegionReference): raise ValueError("cannot read region reference attributes yet") @@ -731,7 +885,7 @@ def open(self): open_flag = self.__mode kwargs = dict() if self.comm: - kwargs.update(driver='mpio', comm=self.comm) + kwargs.update(driver="mpio", comm=self.comm) if self.driver is not None: kwargs.update(driver=self.driver) @@ -777,19 +931,37 @@ def close_linked_files(self): finally: self.__open_links = [] - @docval({'name': 'builder', 'type': GroupBuilder, 'doc': 'the GroupBuilder object representing the HDF5 
file'}, - {'name': 'link_data', 'type': bool, - 'doc': 'If not specified otherwise link (True) or copy (False) HDF5 Datasets', 'default': True}, - {'name': 'exhaust_dci', 'type': bool, - 'doc': 'exhaust DataChunkIterators one at a time. If False, exhaust them concurrently', - 'default': True}, - {'name': 'export_source', 'type': str, - 'doc': 'The source of the builders when exporting', 'default': None}) + @docval( + { + "name": "builder", + "type": GroupBuilder, + "doc": "the GroupBuilder object representing the HDF5 file", + }, + { + "name": "link_data", + "type": bool, + "doc": "If not specified otherwise link (True) or copy (False) HDF5 Datasets", + "default": True, + }, + { + "name": "exhaust_dci", + "type": bool, + "doc": "exhaust DataChunkIterators one at a time. If False, exhaust them concurrently", + "default": True, + }, + { + "name": "export_source", + "type": str, + "doc": "The source of the builders when exporting", + "default": None, + }, + ) def write_builder(self, **kwargs): - f_builder = popargs('builder', kwargs) - link_data, exhaust_dci, export_source = getargs('link_data', 'exhaust_dci', 'export_source', kwargs) - self.logger.debug("Writing GroupBuilder '%s' to path '%s' with kwargs=%s" - % (f_builder.name, self.source, kwargs)) + f_builder = popargs("builder", kwargs) + link_data, exhaust_dci, export_source = getargs("link_data", "exhaust_dci", "export_source", kwargs) + self.logger.debug( + "Writing GroupBuilder '%s' to path '%s' with kwargs=%s" % (f_builder.name, self.source, kwargs) + ) for name, gbldr in f_builder.groups.items(): self.write_group(self.__file, gbldr, **kwargs) for name, dbldr in f_builder.datasets.items(): @@ -800,28 +972,30 @@ def write_builder(self, **kwargs): self.__add_refs() self.__dci_queue.exhaust_queue() self.__set_written(f_builder) - self.logger.debug("Done writing %s '%s' to path '%s'" % - (f_builder.__class__.__qualname__, f_builder.name, self.source)) + self.logger.debug( + "Done writing %s '%s' to path '%s'" % (f_builder.__class__.__qualname__, f_builder.name, self.source) + ) def __add_refs(self): - ''' + """ Add all references in the file. References get queued to be added at the end of write. This is because the current traversal algorithm (i.e. iterating over GroupBuilder items) does not happen in a guaranteed order. We need to figure out what objects will be references, and then write them after we write everything else. - ''' + """ failed = set() while len(self.__ref_queue) > 0: call = self.__ref_queue.popleft() - self.logger.debug("Adding reference with call id %d from queue (length %d)" - % (id(call), len(self.__ref_queue))) + self.logger.debug( + "Adding reference with call id %d from queue (length %d)" % (id(call), len(self.__ref_queue)) + ) try: call() except KeyError: if id(call) in failed: - raise RuntimeError('Unable to resolve reference') + raise RuntimeError("Unable to resolve reference") self.logger.debug("Adding reference with call id %d failed. 
Appending call to queue" % id(call)) failed.add(id(call)) self.__ref_queue.append(call) @@ -834,14 +1008,14 @@ def get_type(cls, data): return H5_BINARY elif isinstance(data, Container): return H5_REF - elif not hasattr(data, '__len__'): + elif not hasattr(data, "__len__"): return type(data) else: if len(data) == 0: - if hasattr(data, 'dtype'): + if hasattr(data, "dtype"): return data.dtype else: - raise ValueError('cannot determine type for empty data') + raise ValueError("cannot determine type for empty data") return cls.get_type(data[0]) __dtypes = { @@ -893,19 +1067,27 @@ def __resolve_dtype_helper__(cls, dtype): elif isinstance(dtype, str): return cls.__dtypes.get(dtype) elif isinstance(dtype, dict): - return cls.__dtypes.get(dtype['reftype']) + return cls.__dtypes.get(dtype["reftype"]) elif isinstance(dtype, np.dtype): # NOTE: some dtypes may not be supported, but we need to support writing of read-in compound types return dtype else: - return np.dtype([(x['name'], cls.__resolve_dtype_helper__(x['dtype'])) for x in dtype]) - - @docval({'name': 'obj', 'type': (Group, Dataset), 'doc': 'the HDF5 object to add attributes to'}, - {'name': 'attributes', - 'type': dict, - 'doc': 'a dict containing the attributes on the Group or Dataset, indexed by attribute name'}) + return np.dtype([(x["name"], cls.__resolve_dtype_helper__(x["dtype"])) for x in dtype]) + + @docval( + { + "name": "obj", + "type": (Group, Dataset), + "doc": "the HDF5 object to add attributes to", + }, + { + "name": "attributes", + "type": dict, + "doc": "a dict containing the attributes on the Group or Dataset, indexed by attribute name", + }, + ) def set_attributes(self, **kwargs): - obj, attributes = getargs('obj', 'attributes', kwargs) + obj, attributes = getargs("obj", "attributes", kwargs) for key, value in attributes.items(): try: if isinstance(value, (set, list, tuple)): @@ -917,50 +1099,82 @@ def set_attributes(self, **kwargs): self.__queue_ref(self._make_attr_ref_filler(obj, key, tmp)) else: value = np.array(value) - self.logger.debug("Setting %s '%s' attribute '%s' to %s" - % (obj.__class__.__name__, obj.name, key, value.__class__.__name__)) + self.logger.debug( + "Setting %s '%s' attribute '%s' to %s" + % (obj.__class__.__name__, obj.name, key, value.__class__.__name__) + ) obj.attrs[key] = value - elif isinstance(value, (Container, Builder, ReferenceBuilder)): # a reference + elif isinstance(value, (Container, Builder, ReferenceBuilder)): # a reference self.__queue_ref(self._make_attr_ref_filler(obj, key, value)) else: - self.logger.debug("Setting %s '%s' attribute '%s' to %s" - % (obj.__class__.__name__, obj.name, key, value.__class__.__name__)) - if isinstance(value, np.ndarray) and value.dtype.kind == 'U': + self.logger.debug( + "Setting %s '%s' attribute '%s' to %s" + % (obj.__class__.__name__, obj.name, key, value.__class__.__name__) + ) + if isinstance(value, np.ndarray) and value.dtype.kind == "U": value = np.array(value, dtype=H5_TEXT) - obj.attrs[key] = value # a regular scalar + obj.attrs[key] = value # a regular scalar except Exception as e: msg = "unable to write attribute '%s' on object '%s'" % (key, obj.name) raise RuntimeError(msg) from e def _make_attr_ref_filler(self, obj, key, value): - ''' - Make the callable for setting references to attributes - ''' - self.logger.debug("Queueing set %s '%s' attribute '%s' to %s" - % (obj.__class__.__name__, obj.name, key, value.__class__.__name__)) + """ + Make the callable for setting references to attributes + """ + self.logger.debug( + "Queueing set %s 
'%s' attribute '%s' to %s" + % (obj.__class__.__name__, obj.name, key, value.__class__.__name__) + ) if isinstance(value, (tuple, list)): + def _filler(): ret = list() for item in value: ret.append(self.__get_ref(item)) obj.attrs[key] = ret + else: + def _filler(): obj.attrs[key] = self.__get_ref(value) + return _filler - @docval({'name': 'parent', 'type': Group, 'doc': 'the parent HDF5 object'}, - {'name': 'builder', 'type': GroupBuilder, 'doc': 'the GroupBuilder to write'}, - {'name': 'link_data', 'type': bool, - 'doc': 'If not specified otherwise link (True) or copy (False) HDF5 Datasets', 'default': True}, - {'name': 'exhaust_dci', 'type': bool, - 'doc': 'exhaust DataChunkIterators one at a time. If False, exhaust them concurrently', - 'default': True}, - {'name': 'export_source', 'type': str, - 'doc': 'The source of the builders when exporting', 'default': None}, - returns='the Group that was created', rtype='Group') + @docval( + { + "name": "parent", + "type": Group, + "doc": "the parent HDF5 object", + }, + { + "name": "builder", + "type": GroupBuilder, + "doc": "the GroupBuilder to write", + }, + { + "name": "link_data", + "type": bool, + "doc": "If not specified otherwise link (True) or copy (False) HDF5 Datasets", + "default": True, + }, + { + "name": "exhaust_dci", + "type": bool, + "doc": "exhaust DataChunkIterators one at a time. If False, exhaust them concurrently", + "default": True, + }, + { + "name": "export_source", + "type": str, + "doc": "The source of the builders when exporting", + "default": None, + }, + returns="the Group that was created", + rtype="Group", + ) def write_group(self, **kwargs): - parent, builder = popargs('parent', 'builder', kwargs) + parent, builder = popargs("parent", "builder", kwargs) self.logger.debug("Writing GroupBuilder '%s' to parent group '%s'" % (builder.name, parent.name)) if self.get_written(builder): self.logger.debug(" GroupBuilder '%s' is already written" % builder.name) @@ -983,7 +1197,11 @@ def write_group(self, **kwargs): links = builder.links if links: for link_name, sub_builder in links.items(): - self.write_link(group, sub_builder, export_source=kwargs.get("export_source")) + self.write_link( + group, + sub_builder, + export_source=kwargs.get("export_source"), + ) attributes = builder.attributes self.set_attributes(group, attributes) self.__set_written(builder) @@ -1009,13 +1227,15 @@ def __get_path(self, builder): path = "%s%s" % (delim, delim.join(reversed(names))) return path - @docval({'name': 'parent', 'type': Group, 'doc': 'the parent HDF5 object'}, - {'name': 'builder', 'type': LinkBuilder, 'doc': 'the LinkBuilder to write'}, - {'name': 'export_source', 'type': str, - 'doc': 'The source of the builders when exporting', 'default': None}, - returns='the Link that was created', rtype='Link') + @docval( + {"name": "parent", "type": Group, "doc": "the parent HDF5 object"}, + {"name": "builder", "type": LinkBuilder, "doc": "the LinkBuilder to write"}, + {"name": "export_source", "type": str, "doc": "The source of the builders when exporting", "default": None}, + returns="the Link that was created", + rtype="Link", + ) def write_link(self, **kwargs): - parent, builder, export_source = getargs('parent', 'builder', 'export_source', kwargs) + parent, builder, export_source = getargs("parent", "builder", "export_source", kwargs) self.logger.debug("Writing LinkBuilder '%s' to parent group '%s'" % (builder.name, parent.name)) if self.get_written(builder): self.logger.debug(" LinkBuilder '%s' is already written" % builder.name) @@ 
-1032,41 +1252,63 @@ def write_link(self, **kwargs): parent_filename = os.path.abspath(parent.file.filename) if target_builder.source in (write_source, parent_filename): link_obj = SoftLink(path) - self.logger.debug(" Creating SoftLink '%s/%s' to '%s'" - % (parent.name, name, link_obj.path)) + self.logger.debug(" Creating SoftLink '%s/%s' to '%s'" % (parent.name, name, link_obj.path)) elif target_builder.source is not None: target_filename = os.path.abspath(target_builder.source) relative_path = os.path.relpath(target_filename, os.path.dirname(parent_filename)) if target_builder.location is not None: path = target_builder.location + "/" + target_builder.name link_obj = ExternalLink(relative_path, path) - self.logger.debug(" Creating ExternalLink '%s/%s' to '%s://%s'" - % (parent.name, name, link_obj.filename, link_obj.path)) + self.logger.debug( + " Creating ExternalLink '%s/%s' to '%s://%s'" % (parent.name, name, link_obj.filename, link_obj.path) + ) else: - msg = 'cannot create external link to %s' % path + msg = "cannot create external link to %s" % path raise ValueError(msg) parent[name] = link_obj self.__set_written(builder) return link_obj - @docval({'name': 'parent', 'type': Group, 'doc': 'the parent HDF5 object'}, # noqa: C901 - {'name': 'builder', 'type': DatasetBuilder, 'doc': 'the DatasetBuilder to write'}, - {'name': 'link_data', 'type': bool, - 'doc': 'If not specified otherwise link (True) or copy (False) HDF5 Datasets', 'default': True}, - {'name': 'exhaust_dci', 'type': bool, - 'doc': 'exhaust DataChunkIterators one at a time. If False, exhaust them concurrently', - 'default': True}, - {'name': 'export_source', 'type': str, - 'doc': 'The source of the builders when exporting', 'default': None}, - returns='the Dataset that was created', rtype=Dataset) + @docval( + { + "name": "parent", + "type": Group, + "doc": "the parent HDF5 object", + }, # noqa: C901 + { + "name": "builder", + "type": DatasetBuilder, + "doc": "the DatasetBuilder to write", + }, + { + "name": "link_data", + "type": bool, + "doc": "If not specified otherwise link (True) or copy (False) HDF5 Datasets", + "default": True, + }, + { + "name": "exhaust_dci", + "type": bool, + "doc": "exhaust DataChunkIterators one at a time. If False, exhaust them concurrently", + "default": True, + }, + { + "name": "export_source", + "type": str, + "doc": "The source of the builders when exporting", + "default": None, + }, + returns="the Dataset that was created", + rtype=Dataset, + ) def write_dataset(self, **kwargs): # noqa: C901 - """ Write a dataset to HDF5 + """Write a dataset to HDF5 The function uses other dataset-dependent write functions, e.g, ``__scalar_fill__``, ``__list_fill__``, and ``__setup_chunked_dset__`` to write the data. 
""" - parent, builder = popargs('parent', 'builder', kwargs) - link_data, exhaust_dci, export_source = getargs('link_data', 'exhaust_dci', 'export_source', kwargs) + parent, builder = popargs("parent", "builder", kwargs) + link_data, exhaust_dci, export_source = getargs("link_data", "exhaust_dci", "export_source", kwargs) self.logger.debug("Writing DatasetBuilder '%s' to parent group '%s'" % (builder.name, parent.name)) if self.get_written(builder): self.logger.debug(" DatasetBuilder '%s' is already written" % builder.name) @@ -1074,16 +1316,16 @@ def write_dataset(self, **kwargs): # noqa: C901 name = builder.name data = builder.data dataio = None - options = dict() # dict with additional + options = dict() # dict with additional if isinstance(data, H5DataIO): - options['io_settings'] = data.io_settings + options["io_settings"] = data.io_settings dataio = data link_data = data.link_data data = data.data else: - options['io_settings'] = {} + options["io_settings"] = {} attributes = builder.attributes - options['dtype'] = builder.dtype + options["dtype"] = builder.dtype dset = None link = None @@ -1096,12 +1338,13 @@ def write_dataset(self, **kwargs): # noqa: C901 if data_filename != parent_filename: # create external link to data relative_path = os.path.relpath(data_filename, os.path.dirname(parent_filename)) link = ExternalLink(relative_path, data.name) - self.logger.debug(" Creating ExternalLink '%s/%s' to '%s://%s'" - % (parent.name, name, link.filename, link.path)) + self.logger.debug( + " Creating ExternalLink '%s/%s' to '%s://%s'" + % (parent.name, name, link.filename, link.path) + ) else: # create soft link to dataset already in this file -- possible if mode == 'r+' link = SoftLink(data.name) - self.logger.debug(" Creating SoftLink '%s/%s' to '%s'" - % (parent.name, name, link.path)) + self.logger.debug(" Creating SoftLink '%s/%s' to '%s'" % (parent.name, name, link.path)) parent[name] = link else: # exporting export_source = os.path.abspath(export_source) @@ -1111,8 +1354,10 @@ def write_dataset(self, **kwargs): # noqa: C901 # to memory relative_path = os.path.relpath(data_filename, os.path.dirname(parent_filename)) link = ExternalLink(relative_path, data.name) - self.logger.debug(" Creating ExternalLink '%s/%s' to '%s://%s'" - % (parent.name, name, link.filename, link.path)) + self.logger.debug( + " Creating ExternalLink '%s/%s' to '%s://%s'" + % (parent.name, name, link.filename, link.path) + ) parent[name] = link elif parent.name != data.parent.name: # dataset is in export source and has different path # so create a soft link to the dataset in this file @@ -1120,59 +1365,72 @@ def write_dataset(self, **kwargs): # noqa: C901 # TODO check that there is/will be still a dataset at data.name -- if the dataset has # been removed, then this link will be broken link = SoftLink(data.name) - self.logger.debug(" Creating SoftLink '%s/%s' to '%s'" - % (parent.name, name, link.path)) + self.logger.debug(" Creating SoftLink '%s/%s' to '%s'" % (parent.name, name, link.path)) parent[name] = link else: # dataset is in export source and has same path as the builder, so copy the dataset - self.logger.debug(" Copying data from '%s://%s' to '%s/%s'" - % (data.file.filename, data.name, parent.name, name)) - parent.copy(source=data, - dest=parent, - name=name, - expand_soft=False, - expand_external=False, - expand_refs=False, - without_attrs=True) - dset = parent[name] - else: - # TODO add option for case where there are multiple links to the same dataset within a file: - # instead of copying the dset 
N times, copy it once and create soft links to it within the file - self.logger.debug(" Copying data from '%s://%s' to '%s/%s'" - % (data.file.filename, data.name, parent.name, name)) - parent.copy(source=data, + self.logger.debug( + " Copying data from '%s://%s' to '%s/%s'" + % (data.file.filename, data.name, parent.name, name) + ) + parent.copy( + source=data, dest=parent, name=name, expand_soft=False, expand_external=False, expand_refs=False, - without_attrs=True) + without_attrs=True, + ) + dset = parent[name] + else: + # TODO add option for case where there are multiple links to the same dataset within a file: + # instead of copying the dset N times, copy it once and create soft links to it within the file + self.logger.debug( + " Copying data from '%s://%s' to '%s/%s'" % (data.file.filename, data.name, parent.name, name) + ) + parent.copy( + source=data, + dest=parent, + name=name, + expand_soft=False, + expand_external=False, + expand_refs=False, + without_attrs=True, + ) dset = parent[name] # Write a compound dataset, i.e, a dataset with compound data type - elif isinstance(options['dtype'], list): + elif isinstance(options["dtype"], list): # do some stuff to figure out what data is a reference refs = list() - for i, dts in enumerate(options['dtype']): + for i, dts in enumerate(options["dtype"]): if self.__is_ref(dts): refs.append(i) # If one or more of the parts of the compound data type are references then we need to deal with those if len(refs) > 0: try: - _dtype = self.__resolve_dtype__(options['dtype'], data) + _dtype = self.__resolve_dtype__(options["dtype"], data) except Exception as exc: - msg = 'cannot add %s to %s - could not determine type' % (name, parent.name) + msg = "cannot add %s to %s - could not determine type" % (name, parent.name) raise Exception(msg) from exc - dset = parent.require_dataset(name, shape=(len(data),), dtype=_dtype, **options['io_settings']) + dset = parent.require_dataset( + name, + shape=(len(data),), + dtype=_dtype, + **options["io_settings"], + ) self.__set_written(builder) - self.logger.debug("Queueing reference resolution and set attribute on dataset '%s' containing " - "object references. attributes: %s" - % (name, list(attributes.keys()))) + self.logger.debug( + "Queueing reference resolution and set attribute on dataset '%s'" + " containing object references. attributes: %s" % (name, list(attributes.keys())) + ) @self.__queue_ref def _filler(): - self.logger.debug("Resolving object references and setting attribute on dataset '%s' " - "containing attributes: %s" - % (name, list(attributes.keys()))) + self.logger.debug( + "Resolving object references and setting attribute on dataset '%s' containing attributes: %s" + % (name, list(attributes.keys())) + ) ret = list() for item in data: new_item = list(item) @@ -1189,88 +1447,110 @@ def _filler(): dset = self.__list_fill__(parent, name, data, options) # Write a dataset containing references, i.e., a region or object reference. # NOTE: we can ignore options['io_settings'] for scalar data - elif self.__is_ref(options['dtype']): - _dtype = self.__dtypes.get(options['dtype']) + elif self.__is_ref(options["dtype"]): + _dtype = self.__dtypes.get(options["dtype"]) # Write a scalar data region reference dataset if isinstance(data, RegionBuilder): dset = parent.require_dataset(name, shape=(), dtype=_dtype) self.__set_written(builder) - self.logger.debug("Queueing reference resolution and set attribute on dataset '%s' containing a " - "region reference. 
attributes: %s" - % (name, list(attributes.keys()))) + self.logger.debug( + "Queueing reference resolution and set attribute on dataset '%s'" + " containing a region reference. attributes: %s" % (name, list(attributes.keys())) + ) @self.__queue_ref def _filler(): - self.logger.debug("Resolving region reference and setting attribute on dataset '%s' " - "containing attributes: %s" - % (name, list(attributes.keys()))) + self.logger.debug( + "Resolving region reference and setting attribute on dataset '%s' containing attributes: %s" + % (name, list(attributes.keys())) + ) ref = self.__get_ref(data.builder, data.region) dset = parent[name] dset[()] = ref self.set_attributes(dset, attributes) + # Write a scalar object reference dataset elif isinstance(data, ReferenceBuilder): dset = parent.require_dataset(name, dtype=_dtype, shape=()) self.__set_written(builder) - self.logger.debug("Queueing reference resolution and set attribute on dataset '%s' containing an " - "object reference. attributes: %s" - % (name, list(attributes.keys()))) + self.logger.debug( + "Queueing reference resolution and set attribute on dataset '%s'" + " containing an object reference. attributes: %s" % (name, list(attributes.keys())) + ) @self.__queue_ref def _filler(): - self.logger.debug("Resolving object reference and setting attribute on dataset '%s' " - "containing attributes: %s" - % (name, list(attributes.keys()))) + self.logger.debug( + "Resolving object reference and setting attribute on dataset '%s' containing attributes: %s" + % (name, list(attributes.keys())) + ) ref = self.__get_ref(data.builder) dset = parent[name] dset[()] = ref self.set_attributes(dset, attributes) + # Write an array dataset of references else: # Write a array of region references - if options['dtype'] == 'region': - dset = parent.require_dataset(name, dtype=_dtype, shape=(len(data),), **options['io_settings']) + if options["dtype"] == "region": + dset = parent.require_dataset( + name, + dtype=_dtype, + shape=(len(data),), + **options["io_settings"], + ) self.__set_written(builder) - self.logger.debug("Queueing reference resolution and set attribute on dataset '%s' containing " - "region references. attributes: %s" - % (name, list(attributes.keys()))) + self.logger.debug( + "Queueing reference resolution and set attribute on dataset" + " '%s' containing region references. attributes: %s" % (name, list(attributes.keys())) + ) @self.__queue_ref def _filler(): - self.logger.debug("Resolving region references and setting attribute on dataset '%s' " - "containing attributes: %s" - % (name, list(attributes.keys()))) + self.logger.debug( + "Resolving region references and setting attribute on" + " dataset '%s' containing attributes: %s" % (name, list(attributes.keys())) + ) refs = list() for item in data: refs.append(self.__get_ref(item.builder, item.region)) dset = parent[name] dset[()] = refs self.set_attributes(dset, attributes) + # Write array of object references else: - dset = parent.require_dataset(name, shape=(len(data),), dtype=_dtype, **options['io_settings']) + dset = parent.require_dataset( + name, + shape=(len(data),), + dtype=_dtype, + **options["io_settings"], + ) self.__set_written(builder) - self.logger.debug("Queueing reference resolution and set attribute on dataset '%s' containing " - "object references. attributes: %s" - % (name, list(attributes.keys()))) + self.logger.debug( + "Queueing reference resolution and set attribute on dataset" + " '%s' containing object references. 
attributes: %s" % (name, list(attributes.keys())) + ) @self.__queue_ref def _filler(): - self.logger.debug("Resolving object references and setting attribute on dataset '%s' " - "containing attributes: %s" - % (name, list(attributes.keys()))) + self.logger.debug( + "Resolving object references and setting attribute on" + " dataset '%s' containing attributes: %s" % (name, list(attributes.keys())) + ) refs = list() for item in data: refs.append(self.__get_ref(item)) dset = parent[name] dset[()] = refs self.set_attributes(dset, attributes) + return # write a "regular" dataset else: # Create an empty dataset if data is None: - dset = self.__setup_empty_dset__(parent, name, options['io_settings']) + dset = self.__setup_empty_dset__(parent, name, options["io_settings"]) dataio.dataset = dset # Write a scalar dataset containing a single string elif isinstance(data, (str, bytes)): @@ -1280,7 +1560,7 @@ def _filler(): dset = self.__setup_chunked_dset__(parent, name, data, options) self.__dci_queue.append(dataset=dset, data=data) # Write a regular in memory array (e.g., numpy array, list etc.) - elif hasattr(data, '__len__'): + elif hasattr(data, "__len__"): dset = self.__list_fill__(parent, name, data, options) # Write a regular scalar dataset else: @@ -1300,13 +1580,13 @@ def __scalar_fill__(cls, parent, name, data, options=None): dtype = None io_settings = {} if options is not None: - dtype = options.get('dtype') - io_settings = options.get('io_settings') + dtype = options.get("dtype") + io_settings = options.get("io_settings") if not isinstance(dtype, type): try: dtype = cls.__resolve_dtype__(dtype, data) except Exception as exc: - msg = 'cannot add %s to %s - could not determine type' % (name, parent.name) + msg = "cannot add %s to %s - could not determine type" % (name, parent.name) raise Exception(msg) from exc try: dset = parent.create_dataset(name, data=data, shape=None, dtype=dtype, **io_settings) @@ -1332,26 +1612,26 @@ def __setup_chunked_dset__(cls, parent, name, data, options=None): """ io_settings = {} if options is not None: - if 'io_settings' in options: - io_settings = options.get('io_settings') + if "io_settings" in options: + io_settings = options.get("io_settings") # Define the chunking options if the user has not set them explicitly. We need chunking for the iterative write. 
- if 'chunks' not in io_settings: + if "chunks" not in io_settings: recommended_chunks = data.recommended_chunk_shape() - io_settings['chunks'] = True if recommended_chunks is None else recommended_chunks + io_settings["chunks"] = True if recommended_chunks is None else recommended_chunks # Define the shape of the data if not provided by the user - if 'shape' not in io_settings: - io_settings['shape'] = data.recommended_data_shape() + if "shape" not in io_settings: + io_settings["shape"] = data.recommended_data_shape() # Define the maxshape of the data if not provided by the user - if 'maxshape' not in io_settings: - io_settings['maxshape'] = data.maxshape - if 'dtype' not in io_settings: - if (options is not None) and ('dtype' in options): - io_settings['dtype'] = options['dtype'] + if "maxshape" not in io_settings: + io_settings["maxshape"] = data.maxshape + if "dtype" not in io_settings: + if (options is not None) and ("dtype" in options): + io_settings["dtype"] = options["dtype"] else: - io_settings['dtype'] = data.dtype - if isinstance(io_settings['dtype'], str): + io_settings["dtype"] = data.dtype + if isinstance(io_settings["dtype"], str): # map to real dtype if we were given a string - io_settings['dtype'] = cls.__dtypes.get(io_settings['dtype']) + io_settings["dtype"] = cls.__dtypes.get(io_settings["dtype"]) try: dset = parent.create_dataset(name, **io_settings) except Exception as exc: @@ -1374,13 +1654,13 @@ def __setup_empty_dset__(cls, parent, name, io_settings): """ # Define the shape of the data if not provided by the user - if 'shape' not in io_settings: + if "shape" not in io_settings: raise ValueError(f"Cannot setup empty dataset {pp(parent.name, name)} without shape") - if 'dtype' not in io_settings: + if "dtype" not in io_settings: raise ValueError(f"Cannot setup empty dataset {pp(parent.name, name)} without dtype") - if isinstance(io_settings['dtype'], str): + if isinstance(io_settings["dtype"], str): # map to real dtype if we were given a string - io_settings['dtype'] = cls.__dtypes.get(io_settings['dtype']) + io_settings["dtype"] = cls.__dtypes.get(io_settings["dtype"]) try: dset = parent.create_dataset(name, **io_settings) except Exception as exc: @@ -1414,18 +1694,18 @@ def __list_fill__(cls, parent, name, data, options=None): io_settings = {} dtype = None if options is not None: - dtype = options.get('dtype') - io_settings = options.get('io_settings') + dtype = options.get("dtype") + io_settings = options.get("io_settings") if not isinstance(dtype, type): try: dtype = cls.__resolve_dtype__(dtype, data) except Exception as exc: - msg = 'cannot add %s to %s - could not determine type' % (name, parent.name) + msg = "cannot add %s to %s - could not determine type" % (name, parent.name) raise Exception(msg) from exc # define the data shape - if 'shape' in io_settings: - data_shape = io_settings.pop('shape') - elif hasattr(data, 'shape'): + if "shape" in io_settings: + data_shape = io_settings.pop("shape") + elif hasattr(data, "shape"): data_shape = data.shape elif isinstance(dtype, np.dtype): data_shape = (len(data),) @@ -1436,8 +1716,14 @@ def __list_fill__(cls, parent, name, data, options=None): try: dset = parent.create_dataset(name, shape=data_shape, dtype=dtype, **io_settings) except Exception as exc: - msg = "Could not create dataset %s in %s with shape %s, dtype %s, and iosettings %s. %s" % \ - (name, parent.name, str(data_shape), str(dtype), str(io_settings), str(exc)) + msg = "Could not create dataset %s in %s with shape %s, dtype %s, and iosettings %s. 
%s" % ( + name, + parent.name, + str(data_shape), + str(dtype), + str(io_settings), + str(exc), + ) raise Exception(msg) from exc # Write the data if len(data) > dset.shape[0]: @@ -1450,13 +1736,24 @@ def __list_fill__(cls, parent, name, data, options=None): raise e return dset - @docval({'name': 'container', 'type': (Builder, Container, ReferenceBuilder), 'doc': 'the object to reference', - 'default': None}, - {'name': 'region', 'type': (slice, list, tuple), 'doc': 'the region reference indexing object', - 'default': None}, - returns='the reference', rtype=Reference) + @docval( + { + "name": "container", + "type": (Builder, Container, ReferenceBuilder), + "doc": "the object to reference", + "default": None, + }, + { + "name": "region", + "type": (slice, list, tuple), + "doc": "the region reference indexing object", + "default": None, + }, + returns="the reference", + rtype=Reference, + ) def __get_ref(self, **kwargs): - container, region = getargs('container', 'region', kwargs) + container, region = getargs("container", "region", kwargs) if container is None: return None if isinstance(container, Builder): @@ -1478,7 +1775,7 @@ def __get_ref(self, **kwargs): if region is not None: dset = self.__file[path] if not isinstance(dset, Dataset): - raise ValueError('cannot create region reference without Dataset') + raise ValueError("cannot create region reference without Dataset") return self.__file[path].regionref[region] else: return self.__file[path].ref @@ -1489,13 +1786,13 @@ def __is_ref(self, dtype): if isinstance(dtype, RefSpec): return True if isinstance(dtype, dict): # may be dict from reading a compound dataset - return self.__is_ref(dtype['dtype']) + return self.__is_ref(dtype["dtype"]) if isinstance(dtype, str): return dtype == DatasetBuilder.OBJECT_REF_TYPE or dtype == DatasetBuilder.REGION_REF_TYPE return False def __queue_ref(self, func): - '''Set aside filling dset with references + """Set aside filling dset with references dest[sl] = func() @@ -1504,7 +1801,7 @@ def __queue_ref(self, func): sl: the np.s_ (slice) object for indexing into dset func: a function to call to return the chunk of data, with references filled in - ''' + """ # TODO: come up with more intelligent way of # queueing reference resolution, based on reference # dependency diff --git a/src/hdmf/backends/io.py b/src/hdmf/backends/io.py index 631185de5..22f82abb3 100644 --- a/src/hdmf/backends/io.py +++ b/src/hdmf/backends/io.py @@ -1,26 +1,37 @@ -from abc import ABCMeta, abstractmethod import os +from abc import ABCMeta, abstractmethod from pathlib import Path from ..build import BuildManager, GroupBuilder from ..container import Container -from .errors import UnsupportedOperation from ..utils import docval, getargs, popargs +from .errors import UnsupportedOperation class HDMFIO(metaclass=ABCMeta): - @docval({'name': 'manager', 'type': BuildManager, - 'doc': 'the BuildManager to use for I/O', 'default': None}, - {"name": "source", "type": (str, Path), - "doc": "the source of container being built i.e. file path", 'default': None}) + @docval( + { + "name": "manager", + "type": BuildManager, + "doc": "the BuildManager to use for I/O", + "default": None, + }, + { + "name": "source", + "type": (str, Path), + "doc": "the source of container being built i.e. 
file path", + "default": None, + }, + ) def __init__(self, **kwargs): - manager, source = getargs('manager', 'source', kwargs) + manager, source = getargs("manager", "source", kwargs) if isinstance(source, Path): source = source.resolve() - elif (isinstance(source, str) and - not (source.lower().startswith("http://") or - source.lower().startswith("https://") or - source.lower().startswith("s3://"))): + elif isinstance(source, str) and not ( + source.lower().startswith("http://") + or source.lower().startswith("https://") + or source.lower().startswith("s3://") + ): source = os.path.abspath(source) self.__manager = manager @@ -30,41 +41,62 @@ def __init__(self, **kwargs): @property def manager(self): - '''The BuildManager this instance is using''' + """The BuildManager this instance is using""" return self.__manager @property def source(self): - '''The source of the container being read/written i.e. file path''' + """The source of the container being read/written i.e. file path""" return self.__source - @docval(returns='the Container object that was read in', rtype=Container) + @docval(returns="the Container object that was read in", rtype=Container) def read(self, **kwargs): """Read a container from the IO source.""" f_builder = self.read_builder() if all(len(v) == 0 for v in f_builder.values()): # TODO also check that the keys are appropriate. print a better error message - raise UnsupportedOperation('Cannot build data. There are no values.') + raise UnsupportedOperation("Cannot build data. There are no values.") container = self.__manager.construct(f_builder) return container - @docval({'name': 'container', 'type': Container, 'doc': 'the Container object to write'}, - allow_extra=True) + @docval( + {"name": "container", "type": Container, "doc": "the Container object to write"}, + allow_extra=True, + ) def write(self, **kwargs): """Write a container to the IO source.""" - container = popargs('container', kwargs) + container = popargs("container", kwargs) f_builder = self.__manager.build(container, source=self.__source, root=True) self.write_builder(f_builder, **kwargs) - @docval({'name': 'src_io', 'type': 'HDMFIO', 'doc': 'the HDMFIO object for reading the data to export'}, - {'name': 'container', 'type': Container, - 'doc': ('the Container object to export. If None, then the entire contents of the HDMFIO object will be ' - 'exported'), - 'default': None}, - {'name': 'write_args', 'type': dict, 'doc': 'arguments to pass to :py:meth:`write_builder`', - 'default': dict()}, - {'name': 'clear_cache', 'type': bool, 'doc': 'whether to clear the build manager cache', - 'default': False}) + @docval( + { + "name": "src_io", + "type": "HDMFIO", + "doc": "the HDMFIO object for reading the data to export", + }, + { + "name": "container", + "type": Container, + "doc": ( + "the Container object to export. If None, then the entire contents of" + " the HDMFIO object will be exported" + ), + "default": None, + }, + { + "name": "write_args", + "type": dict, + "doc": "arguments to pass to :py:meth:`write_builder`", + "default": dict(), + }, + { + "name": "clear_cache", + "type": bool, + "doc": "whether to clear the build manager cache", + "default": False, + }, + ) def export(self, **kwargs): """Export from one backend to the backend represented by this class. @@ -92,7 +124,7 @@ def export(self, **kwargs): and LinkBuilder.builder.source are the same, and if so the link should be internal to the current file (even if the Builder.source points to a different location). 
""" - src_io, container, write_args, clear_cache = getargs('src_io', 'container', 'write_args', 'clear_cache', kwargs) + src_io, container, write_args, clear_cache = getargs("src_io", "container", "write_args", "clear_cache", kwargs) if container is None and clear_cache: # clear all containers and builders from cache so that they can all get rebuilt with export=True. # constructing the container is not efficient but there is no elegant way to trigger a @@ -101,14 +133,16 @@ def export(self, **kwargs): if container is not None: # check that manager exists, container was built from manager, and container is root of hierarchy if src_io.manager is None: - raise ValueError('When a container is provided, src_io must have a non-None manager (BuildManager) ' - 'property.') + raise ValueError( + "When a container is provided, src_io must have a non-None manager (BuildManager) property." + ) old_bldr = src_io.manager.get_builder(container) if old_bldr is None: - raise ValueError('The provided container must have been read by the provided src_io.') + raise ValueError("The provided container must have been read by the provided src_io.") if old_bldr.parent is not None: - raise ValueError('The provided container must be the root of the hierarchy of the ' - 'source used to read the container.') + raise ValueError( + "The provided container must be the root of the hierarchy of the source used to read the container." + ) # NOTE in HDF5IO, clear_cache is set to True when link_data is False if clear_cache: @@ -123,26 +157,28 @@ def export(self, **kwargs): self.write_builder(builder=bldr, **write_args) @abstractmethod - @docval(returns='a GroupBuilder representing the read data', rtype='GroupBuilder') + @docval(returns="a GroupBuilder representing the read data", rtype="GroupBuilder") def read_builder(self): - ''' Read data and return the GroupBuilder representing it ''' + """Read data and return the GroupBuilder representing it""" pass @abstractmethod - @docval({'name': 'builder', 'type': GroupBuilder, 'doc': 'the GroupBuilder object representing the Container'}, - allow_extra=True) + @docval( + {"name": "builder", "type": GroupBuilder, "doc": "the GroupBuilder object representing the Container"}, + allow_extra=True, + ) def write_builder(self, **kwargs): - ''' Write a GroupBuilder representing an Container object ''' + """Write a GroupBuilder representing an Container object""" pass @abstractmethod def open(self): - ''' Open this HDMFIO object for writing of the builder ''' + """Open this HDMFIO object for writing of the builder""" pass @abstractmethod def close(self): - ''' Close this HDMFIO object to further reading/writing''' + """Close this HDMFIO object to further reading/writing""" pass def __enter__(self): diff --git a/src/hdmf/backends/utils.py b/src/hdmf/backends/utils.py index 95eafe025..bed9ef57f 100644 --- a/src/hdmf/backends/utils.py +++ b/src/hdmf/backends/utils.py @@ -1,7 +1,8 @@ """Module with utility functions and classes used for implementation of I/O backends""" import os -from ..spec import NamespaceCatalog, GroupSpec, NamespaceBuilder -from ..utils import docval, popargs + +from ..spec import GroupSpec, NamespaceBuilder, NamespaceCatalog +from ..utils import docval, popargs class WriteStatusTracker(dict): @@ -9,6 +10,7 @@ class WriteStatusTracker(dict): Helper class used for tracking the write status of builders. I.e., to track whether a builder has been written or not. 
""" + def __init__(self): pass @@ -47,24 +49,29 @@ class NamespaceToBuilderHelper(object): """Helper class used in HDF5IO (and possibly elsewhere) to convert a namespace to a builder for I/O""" @classmethod - @docval({'name': 'ns_catalog', 'type': NamespaceCatalog, 'doc': 'the namespace catalog with the specs'}, - {'name': 'namespace', 'type': str, 'doc': 'the name of the namespace to be converted to a builder'}, - rtype=NamespaceBuilder) + @docval( + {"name": "ns_catalog", "type": NamespaceCatalog, "doc": "the namespace catalog with the specs"}, + {"name": "namespace", "type": str, "doc": "the name of the namespace to be converted to a builder"}, + rtype=NamespaceBuilder, + ) def convert_namespace(cls, **kwargs): """Convert a namespace to a builder""" - ns_catalog, namespace = popargs('ns_catalog', 'namespace', kwargs) + ns_catalog, namespace = popargs("ns_catalog", "namespace", kwargs) ns = ns_catalog.get_namespace(namespace) - builder = NamespaceBuilder(ns.doc, ns.name, - full_name=ns.full_name, - version=ns.version, - author=ns.author, - contact=ns.contact) + builder = NamespaceBuilder( + ns.doc, + ns.name, + full_name=ns.full_name, + version=ns.version, + author=ns.author, + contact=ns.contact, + ) for elem in ns.schema: - if 'namespace' in elem: - inc_ns = elem['namespace'] + if "namespace" in elem: + inc_ns = elem["namespace"] builder.include_namespace(inc_ns) else: - source = elem['source'] + source = elem["source"] for dt in ns_catalog.get_types(source): spec = ns_catalog.get_spec(namespace, dt) if spec.parent is not None: @@ -75,23 +82,31 @@ def convert_namespace(cls, **kwargs): return builder @classmethod - @docval({'name': 'source', 'type': str, 'doc': "source path"}) + @docval({"name": "source", "type": str, "doc": "source path"}) def get_source_name(self, source): return os.path.splitext(source)[0] @classmethod def __copy_spec(cls, spec): kwargs = dict() - kwargs['attributes'] = cls.__get_new_specs(spec.attributes, spec) - to_copy = ['doc', 'name', 'default_name', 'linkable', 'quantity', spec.inc_key(), spec.def_key()] + kwargs["attributes"] = cls.__get_new_specs(spec.attributes, spec) + to_copy = [ + "doc", + "name", + "default_name", + "linkable", + "quantity", + spec.inc_key(), + spec.def_key(), + ] if isinstance(spec, GroupSpec): - kwargs['datasets'] = cls.__get_new_specs(spec.datasets, spec) - kwargs['groups'] = cls.__get_new_specs(spec.groups, spec) - kwargs['links'] = cls.__get_new_specs(spec.links, spec) + kwargs["datasets"] = cls.__get_new_specs(spec.datasets, spec) + kwargs["groups"] = cls.__get_new_specs(spec.groups, spec) + kwargs["links"] = cls.__get_new_specs(spec.links, spec) else: - to_copy.append('dtype') - to_copy.append('shape') - to_copy.append('dims') + to_copy.append("dtype") + to_copy.append("shape") + to_copy.append("dims") for key in to_copy: val = getattr(spec, key) if val is not None: diff --git a/src/hdmf/backends/warnings.py b/src/hdmf/backends/warnings.py index 77a711584..8151cfc60 100644 --- a/src/hdmf/backends/warnings.py +++ b/src/hdmf/backends/warnings.py @@ -2,4 +2,5 @@ class BrokenLinkWarning(UserWarning): """ Raised when a group has a key with a None value. 
""" + pass diff --git a/src/hdmf/build/__init__.py b/src/hdmf/build/__init__.py index ea5d21152..e299d222f 100644 --- a/src/hdmf/build/__init__.py +++ b/src/hdmf/build/__init__.py @@ -1,8 +1,26 @@ -from .builders import Builder, DatasetBuilder, GroupBuilder, LinkBuilder, ReferenceBuilder, RegionBuilder +from .builders import ( + Builder, + DatasetBuilder, + GroupBuilder, + LinkBuilder, + ReferenceBuilder, + RegionBuilder, +) from .classgenerator import CustomClassGenerator, MCIClassGenerator -from .errors import (BuildError, OrphanContainerBuildError, ReferenceTargetNotBuiltError, ContainerConfigurationError, - ConstructError) +from .errors import ( + BuildError, + OrphanContainerBuildError, + ReferenceTargetNotBuiltError, + ContainerConfigurationError, + ConstructError, +) from .manager import BuildManager, TypeMap from .objectmapper import ObjectMapper -from .warnings import (BuildWarning, MissingRequiredBuildWarning, DtypeConversionWarning, IncorrectQuantityBuildWarning, - MissingRequiredWarning, OrphanContainerWarning) +from .warnings import ( + BuildWarning, + MissingRequiredBuildWarning, + DtypeConversionWarning, + IncorrectQuantityBuildWarning, + MissingRequiredWarning, + OrphanContainerWarning, +) diff --git a/src/hdmf/build/builders.py b/src/hdmf/build/builders.py index f96e6016a..da3f5f6bb 100644 --- a/src/hdmf/build/builders.py +++ b/src/hdmf/build/builders.py @@ -12,13 +12,27 @@ class Builder(dict, metaclass=ABCMeta): - - @docval({'name': 'name', 'type': str, 'doc': 'the name of the group'}, - {'name': 'parent', 'type': 'Builder', 'doc': 'the parent builder of this Builder', 'default': None}, - {'name': 'source', 'type': str, - 'doc': 'the source of the data in this builder e.g. file name', 'default': None}) + @docval( + { + "name": "name", + "type": str, + "doc": "the name of the Builder", + }, + { + "name": "parent", + "type": "Builder", + "doc": "the parent builder of this Builder", + "default": None, + }, + { + "name": "source", + "type": str, + "doc": "the source of the data in this builder e.g. 
file name", + "default": None, + }, + ) def __init__(self, **kwargs): - name, parent, source = getargs('name', 'parent', 'source', kwargs) + name, parent, source = getargs("name", "parent", "source", kwargs) super().__init__() self.__name = name self.__parent = parent @@ -52,7 +66,7 @@ def source(self): @source.setter def source(self, s): if self.__source is not None: - raise AttributeError('Cannot overwrite source.') + raise AttributeError("Cannot overwrite source.") self.__source = s @property @@ -63,7 +77,7 @@ def parent(self): @parent.setter def parent(self, p): if self.__parent is not None: - raise AttributeError('Cannot overwrite parent.') + raise AttributeError("Cannot overwrite parent.") self.__parent = p if self.__source is None: self.source = p.source @@ -74,16 +88,35 @@ def __repr__(self): class BaseBuilder(Builder, metaclass=ABCMeta): - __attribute = 'attributes' # self dictionary key for attributes - - @docval({'name': 'name', 'type': str, 'doc': 'The name of the builder.'}, - {'name': 'attributes', 'type': dict, 'doc': 'A dictionary of attributes to create in this builder.', - 'default': dict()}, - {'name': 'parent', 'type': 'GroupBuilder', 'doc': 'The parent builder of this builder.', 'default': None}, - {'name': 'source', 'type': str, - 'doc': 'The source of the data represented in this builder', 'default': None}) + __attribute = "attributes" # self dictionary key for attributes + + @docval( + { + "name": "name", + "type": str, + "doc": "The name of the builder.", + }, + { + "name": "attributes", + "type": dict, + "doc": "A dictionary of attributes to create in this builder.", + "default": dict(), + }, + { + "name": "parent", + "type": "GroupBuilder", + "doc": "The parent builder of this builder.", + "default": None, + }, + { + "name": "source", + "type": str, + "doc": "The source of the data represented in this builder", + "default": None, + }, + ) def __init__(self, **kwargs): - name, attributes, parent, source = getargs('name', 'attributes', 'parent', 'source', kwargs) + name, attributes, parent, source = getargs("name", "attributes", "parent", "source", kwargs) super().__init__(name, parent, source) super().__setitem__(BaseBuilder.__attribute, dict()) for name, val in attributes.items(): @@ -104,43 +137,78 @@ def attributes(self): """The attributes stored in this Builder object.""" return super().__getitem__(BaseBuilder.__attribute) - @docval({'name': 'name', 'type': str, 'doc': 'The name of the attribute.'}, - {'name': 'value', 'type': None, 'doc': 'The attribute value.'}) + @docval( + {"name": "name", "type": str, "doc": "The name of the attribute."}, + {"name": "value", "type": None, "doc": "The attribute value."}, + ) def set_attribute(self, **kwargs): """Set an attribute for this group.""" - name, value = getargs('name', 'value', kwargs) + name, value = getargs("name", "value", kwargs) self.attributes[name] = value class GroupBuilder(BaseBuilder): # sub-dictionary keys. subgroups go in super().__getitem__(GroupBuilder.__group) - __group = 'groups' - __dataset = 'datasets' - __link = 'links' - __attribute = 'attributes' - - @docval({'name': 'name', 'type': str, 'doc': 'The name of the group.'}, - {'name': 'groups', 'type': (dict, list), - 'doc': ('A dictionary or list of subgroups to add to this group. If a dict is provided, only the ' - 'values are used.'), - 'default': dict()}, - {'name': 'datasets', 'type': (dict, list), - 'doc': ('A dictionary or list of datasets to add to this group. 
If a dict is provided, only the ' - 'values are used.'), - 'default': dict()}, - {'name': 'attributes', 'type': dict, 'doc': 'A dictionary of attributes to create in this group.', - 'default': dict()}, - {'name': 'links', 'type': (dict, list), - 'doc': ('A dictionary or list of links to add to this group. If a dict is provided, only the ' - 'values are used.'), - 'default': dict()}, - {'name': 'parent', 'type': 'GroupBuilder', 'doc': 'The parent builder of this builder.', 'default': None}, - {'name': 'source', 'type': str, - 'doc': 'The source of the data represented in this builder.', 'default': None}) + __group = "groups" + __dataset = "datasets" + __link = "links" + __attribute = "attributes" + + @docval( + { + "name": "name", + "type": str, + "doc": "The name of the group.", + }, + { + "name": "groups", + "type": (dict, list), + "doc": ( + "A dictionary or list of subgroups to add to this group. If a dict is provided," + " only the values are used." + ), + "default": dict(), + }, + { + "name": "datasets", + "type": (dict, list), + "doc": ( + "A dictionary or list of datasets to add to this group. If a dict is provided, only" + " the values are used." + ), + "default": dict(), + }, + { + "name": "attributes", + "type": dict, + "doc": "A dictionary of attributes to create in this group.", + "default": dict(), + }, + { + "name": "links", + "type": (dict, list), + "doc": ( + "A dictionary or list of links to add to this group. If a dict is provided, only the values are used." + ), + "default": dict(), + }, + { + "name": "parent", + "type": "GroupBuilder", + "doc": "The parent builder of this builder.", + "default": None, + }, + { + "name": "source", + "type": str, + "doc": "The source of the data represented in this builder.", + "default": None, + }, + ) def __init__(self, **kwargs): """Create a builder object for a group.""" - name, groups, datasets, links, attributes, parent, source = getargs( - 'name', 'groups', 'datasets', 'links', 'attributes', 'parent', 'source', kwargs) + name, groups, datasets, links = getargs("name", "groups", "datasets", "links", kwargs) + attributes, parent, source = getargs("attributes", "parent", "source", kwargs) # NOTE: if groups, datasets, or links are dicts, their keys are unused groups = self.__to_list(groups) datasets = self.__to_list(datasets) @@ -167,7 +235,7 @@ def __to_list(self, d): @property def source(self): - ''' The source of this Builder ''' + """The source of this Builder""" return super().source @source.setter @@ -202,7 +270,7 @@ def links(self): @docval(*get_docval(BaseBuilder.set_attribute)) def set_attribute(self, **kwargs): """Set an attribute for this group.""" - name, value = getargs('name', 'value', kwargs) + name, value = getargs("name", "value", kwargs) self.__check_obj_type(name, GroupBuilder.__attribute) super().set_attribute(name, value) self.obj_type[name] = GroupBuilder.__attribute @@ -210,25 +278,26 @@ def set_attribute(self, **kwargs): def __check_obj_type(self, name, obj_type): # check that the name is not associated with a different object type in this group if name in self.obj_type and self.obj_type[name] != obj_type: - raise ValueError("'%s' already exists in %s.%s, cannot set in %s." - % (name, self.name, self.obj_type[name], obj_type)) + raise ValueError( + "'%s' already exists in %s.%s, cannot set in %s." 
% (name, self.name, self.obj_type[name], obj_type) + ) - @docval({'name': 'builder', 'type': 'GroupBuilder', 'doc': 'The GroupBuilder to add to this group.'}) + @docval({"name": "builder", "type": "GroupBuilder", "doc": "The GroupBuilder to add to this group."}) def set_group(self, **kwargs): """Add a subgroup to this group.""" - builder = getargs('builder', kwargs) + builder = getargs("builder", kwargs) self.__set_builder(builder, GroupBuilder.__group) - @docval({'name': 'builder', 'type': 'DatasetBuilder', 'doc': 'The DatasetBuilder to add to this group.'}) + @docval({"name": "builder", "type": "DatasetBuilder", "doc": "The DatasetBuilder to add to this group."}) def set_dataset(self, **kwargs): """Add a dataset to this group.""" - builder = getargs('builder', kwargs) + builder = getargs("builder", kwargs) self.__set_builder(builder, GroupBuilder.__dataset) - @docval({'name': 'builder', 'type': 'LinkBuilder', 'doc': 'The LinkBuilder to add to this group.'}) + @docval({"name": "builder", "type": "LinkBuilder", "doc": "The LinkBuilder to add to this group."}) def set_link(self, **kwargs): """Add a link to this group.""" - builder = getargs('builder', kwargs) + builder = getargs("builder", kwargs) self.__set_builder(builder, GroupBuilder.__link) def __set_builder(self, builder, obj_type): @@ -258,7 +327,7 @@ def __getitem__(self, key): Key can be a posix path to a sub-builder. """ try: - key_ar = _posixpath.normpath(key).split('/') + key_ar = _posixpath.normpath(key).split("/") return self.__get_rec(key_ar) except KeyError: raise KeyError(key) @@ -268,7 +337,7 @@ def get(self, key, default=None): Key can be a posix path to a sub-builder. """ try: - key_ar = _posixpath.normpath(key).split('/') + key_ar = _posixpath.normpath(key).split("/") return self.__get_rec(key_ar) except KeyError: return default @@ -285,57 +354,107 @@ def __get_rec(self, key_ar): raise KeyError(key_ar[0]) def __setitem__(self, args, val): - raise NotImplementedError('__setitem__') + raise NotImplementedError("__setitem__") def __contains__(self, item): return self.obj_type.__contains__(item) def items(self): """Like dict.items, but iterates over items in groups, datasets, attributes, and links sub-dictionaries.""" - return _itertools.chain(self.groups.items(), - self.datasets.items(), - self.attributes.items(), - self.links.items()) + return _itertools.chain( + self.groups.items(), + self.datasets.items(), + self.attributes.items(), + self.links.items(), + ) def keys(self): """Like dict.keys, but iterates over keys in groups, datasets, attributes, and links sub-dictionaries.""" - return _itertools.chain(self.groups.keys(), - self.datasets.keys(), - self.attributes.keys(), - self.links.keys()) + return _itertools.chain( + self.groups.keys(), + self.datasets.keys(), + self.attributes.keys(), + self.links.keys(), + ) def values(self): """Like dict.values, but iterates over values in groups, datasets, attributes, and links sub-dictionaries.""" - return _itertools.chain(self.groups.values(), - self.datasets.values(), - self.attributes.values(), - self.links.values()) + return _itertools.chain( + self.groups.values(), + self.datasets.values(), + self.attributes.values(), + self.links.values(), + ) class DatasetBuilder(BaseBuilder): - OBJECT_REF_TYPE = 'object' - REGION_REF_TYPE = 'region' - - @docval({'name': 'name', 'type': str, 'doc': 'The name of the dataset.'}, - {'name': 'data', - 'type': ('array_data', 'scalar_data', 'data', 'DatasetBuilder', 'RegionBuilder', Iterable, datetime), - 'doc': 'The data in this dataset.', 
'default': None}, - {'name': 'dtype', 'type': (type, np.dtype, str, list), - 'doc': 'The datatype of this dataset.', 'default': None}, - {'name': 'attributes', 'type': dict, - 'doc': 'A dictionary of attributes to create in this dataset.', 'default': dict()}, - {'name': 'maxshape', 'type': (int, tuple), - 'doc': 'The shape of this dataset. Use None for scalars.', 'default': None}, - {'name': 'chunks', 'type': bool, 'doc': 'Whether or not to chunk this dataset.', 'default': False}, - {'name': 'parent', 'type': GroupBuilder, 'doc': 'The parent builder of this builder.', 'default': None}, - {'name': 'source', 'type': str, 'doc': 'The source of the data in this builder.', 'default': None}) + OBJECT_REF_TYPE = "object" + REGION_REF_TYPE = "region" + + @docval( + { + "name": "name", + "type": str, + "doc": "The name of the dataset.", + }, + { + "name": "data", + "type": ( + "array_data", + "scalar_data", + "data", + "DatasetBuilder", + "RegionBuilder", + Iterable, + datetime, + ), + "doc": "The data in this dataset.", + "default": None, + }, + { + "name": "dtype", + "type": (type, np.dtype, str, list), + "doc": "The datatype of this dataset.", + "default": None, + }, + { + "name": "attributes", + "type": dict, + "doc": "A dictionary of attributes to create in this dataset.", + "default": dict(), + }, + { + "name": "maxshape", + "type": (int, tuple), + "doc": "The shape of this dataset. Use None for scalars.", + "default": None, + }, + { + "name": "chunks", + "type": bool, + "doc": "Whether or not to chunk this dataset.", + "default": False, + }, + { + "name": "parent", + "type": GroupBuilder, + "doc": "The parent builder of this builder.", + "default": None, + }, + { + "name": "source", + "type": str, + "doc": "The source of the data in this builder.", + "default": None, + }, + ) def __init__(self, **kwargs): - """ Create a Builder object for a dataset """ - name, data, dtype, attributes, maxshape, chunks, parent, source = getargs( - 'name', 'data', 'dtype', 'attributes', 'maxshape', 'chunks', 'parent', 'source', kwargs) + """Create a Builder object for a dataset""" + name, data, dtype, attributes = getargs("name", "data", "dtype", "attributes", kwargs) + maxshape, chunks, parent, source = getargs("maxshape", "chunks", "parent", "source", kwargs) super().__init__(name, attributes, parent, source) - self['data'] = data - self['attributes'] = _copy.copy(attributes) + self["data"] = data + self["attributes"] = _copy.copy(attributes) self.__chunks = chunks self.__maxshape = maxshape if isinstance(data, BaseBuilder): @@ -347,13 +466,13 @@ def __init__(self, **kwargs): @property def data(self): """The data stored in the dataset represented by this builder.""" - return self['data'] + return self["data"] @data.setter def data(self, val): - if self['data'] is not None: + if self["data"] is not None: raise AttributeError("Cannot overwrite data.") - self['data'] = val + self["data"] = val @property def chunks(self): @@ -378,53 +497,84 @@ def dtype(self, val): class LinkBuilder(Builder): - - @docval({'name': 'builder', 'type': (DatasetBuilder, GroupBuilder), - 'doc': 'The target group or dataset of this link.'}, - {'name': 'name', 'type': str, 'doc': 'The name of the link', 'default': None}, - {'name': 'parent', 'type': GroupBuilder, 'doc': 'The parent builder of this builder', 'default': None}, - {'name': 'source', 'type': str, 'doc': 'The source of the data in this builder', 'default': None}) + @docval( + { + "name": "builder", + "type": (DatasetBuilder, GroupBuilder), + "doc": "The target group or 
dataset of this link.", + }, + { + "name": "name", + "type": str, + "doc": "The name of the link", + "default": None, + }, + { + "name": "parent", + "type": GroupBuilder, + "doc": "The parent builder of this builder", + "default": None, + }, + { + "name": "source", + "type": str, + "doc": "The source of the data in this builder", + "default": None, + }, + ) def __init__(self, **kwargs): """Create a builder object for a link.""" - name, builder, parent, source = getargs('name', 'builder', 'parent', 'source', kwargs) + name, builder, parent, source = getargs("name", "builder", "parent", "source", kwargs) if name is None: name = builder.name super().__init__(name, parent, source) - self['builder'] = builder + self["builder"] = builder @property def builder(self): """The target builder object.""" - return self['builder'] + return self["builder"] class ReferenceBuilder(dict): - - @docval({'name': 'builder', 'type': (DatasetBuilder, GroupBuilder), - 'doc': 'The group or dataset this reference applies to.'}) + @docval( + { + "name": "builder", + "type": (DatasetBuilder, GroupBuilder), + "doc": "The group or dataset this reference applies to.", + } + ) def __init__(self, **kwargs): """Create a builder object for a reference.""" - builder = getargs('builder', kwargs) - self['builder'] = builder + builder = getargs("builder", kwargs) + self["builder"] = builder @property def builder(self): """The target builder object.""" - return self['builder'] + return self["builder"] class RegionBuilder(ReferenceBuilder): - - @docval({'name': 'region', 'type': (slice, tuple, list, RegionReference), - 'doc': 'The region, i.e. slice or indices, into the target dataset.'}, - {'name': 'builder', 'type': DatasetBuilder, 'doc': 'The dataset this region reference applies to.'}) + @docval( + { + "name": "region", + "type": (slice, tuple, list, RegionReference), + "doc": "The region, i.e. slice or indices, into the target dataset.", + }, + { + "name": "builder", + "type": DatasetBuilder, + "doc": "The dataset this region reference applies to.", + }, + ) def __init__(self, **kwargs): """Create a builder object for a region reference.""" - region, builder = getargs('region', 'builder', kwargs) + region, builder = getargs("region", "builder", kwargs) super().__init__(builder) - self['region'] = region + self["region"] = region @property def region(self): """The selected region of the target dataset.""" - return self['region'] + return self["region"] diff --git a/src/hdmf/build/classgenerator.py b/src/hdmf/build/classgenerator.py index 113277168..8684d92d7 100644 --- a/src/hdmf/build/classgenerator.py +++ b/src/hdmf/build/classgenerator.py @@ -10,7 +10,6 @@ class ClassGenerator: - def __init__(self): self.__custom_generators = [] @@ -18,36 +17,39 @@ def __init__(self): def custom_generators(self): return self.__custom_generators - @docval({'name': 'generator', 'type': type, 'doc': 'the CustomClassGenerator class to register'}) + @docval({"name": "generator", "type": type, "doc": "the CustomClassGenerator class to register"}) def register_generator(self, **kwargs): """Add a custom class generator to this ClassGenerator. Generators added later are run first. Duplicates are moved to the top of the list. """ - generator = getargs('generator', kwargs) + generator = getargs("generator", kwargs) if not issubclass(generator, CustomClassGenerator): - raise ValueError('Generator %s must be a subclass of CustomClassGenerator.' % generator) + raise ValueError("Generator %s must be a subclass of CustomClassGenerator." 
% generator) if generator in self.__custom_generators: self.__custom_generators.remove(generator) self.__custom_generators.insert(0, generator) - @docval({'name': 'data_type', 'type': str, 'doc': 'the data type to create a AbstractContainer class for'}, - {'name': 'spec', 'type': BaseStorageSpec, 'doc': ''}, - {'name': 'parent_cls', 'type': type, 'doc': ''}, - {'name': 'attr_names', 'type': dict, 'doc': ''}, - {'name': 'type_map', 'type': 'TypeMap', 'doc': ''}, - returns='the class for the given namespace and data_type', rtype=type) + @docval( + {"name": "data_type", "type": str, "doc": "the data type to create a AbstractContainer class for"}, + {"name": "spec", "type": BaseStorageSpec, "doc": ""}, + {"name": "parent_cls", "type": type, "doc": ""}, + {"name": "attr_names", "type": dict, "doc": ""}, + {"name": "type_map", "type": "TypeMap", "doc": ""}, + returns="the class for the given namespace and data_type", + rtype=type, + ) def generate_class(self, **kwargs): """Get the container class from data type specification. If no class has been associated with the ``data_type`` from ``namespace``, a class will be dynamically created and returned. """ - data_type, spec, parent_cls, attr_names, type_map = getargs('data_type', 'spec', 'parent_cls', 'attr_names', - 'type_map', kwargs) + data_type, spec, parent_cls = getargs("data_type", "spec", "parent_cls", kwargs) + attr_names, type_map = getargs("attr_names", "type_map", kwargs) not_inherited_fields = dict() for k, field_spec in attr_names.items(): - if k == 'help': # pragma: no cover + if k == "help": # pragma: no cover # (legacy) do not add field named 'help' to any part of class object continue if isinstance(field_spec, GroupSpec) and field_spec.data_type is None: # skip named, untyped groups @@ -62,8 +64,15 @@ def generate_class(self, **kwargs): for class_generator in self.__custom_generators: # pragma: no branch # each generator can update classdict and docval_args if class_generator.apply_generator_to_field(field_spec, bases, type_map): - class_generator.process_field_spec(classdict, docval_args, parent_cls, attr_name, - not_inherited_fields, type_map, spec) + class_generator.process_field_spec( + classdict, + docval_args, + parent_cls, + attr_name, + not_inherited_fields, + type_map, + spec, + ) break # each field_spec should be processed by only one generator for class_generator in self.__custom_generators: @@ -77,10 +86,12 @@ def generate_class(self, **kwargs): # this error should never happen after hdmf#322 name = spec.data_type_def if name is None: - name = 'Unknown' - raise ValueError("Cannot dynamically generate class for type '%s'. " % name - + str(e) - + " Please define that type before defining '%s'." % name) + name = "Unknown" + raise ValueError( + "Cannot dynamically generate class for type '%s'. " % name + + str(e) + + " Please define that type before defining '%s'." 
% name + ) cls = ExtenderMeta(data_type, tuple(bases), classdict) return cls @@ -93,7 +104,7 @@ class CustomClassGenerator: """Subclass this class and register an instance to alter how classes are auto-generated.""" def __new__(cls, *args, **kwargs): # pragma: no cover - raise TypeError('Cannot instantiate class %s' % cls.__name__) + raise TypeError("Cannot instantiate class %s" % cls.__name__) # mapping from spec types to allowable python types for docval for fields during dynamic class generation # e.g., if a dataset/attribute spec has dtype int32, then get_class should generate a docval for the class' @@ -102,32 +113,44 @@ def __new__(cls, *args, **kwargs): # pragma: no cover # passing an int64 to __init__ would result in the field storing the value as an int64 (and subsequently written # as an int64). no upconversion or downconversion happens as a result of this map _spec_dtype_map = { - 'float32': (float, np.float32, np.float64), - 'float': (float, np.float32, np.float64), - 'float64': (float, np.float64), - 'double': (float, np.float64), - 'int8': (np.int8, np.int16, np.int32, np.int64, int), - 'int16': (np.int16, np.int32, np.int64, int), - 'short': (np.int16, np.int32, np.int64, int), - 'int32': (int, np.int32, np.int64), - 'int': (int, np.int32, np.int64), - 'int64': np.int64, - 'long': np.int64, - 'uint8': (np.uint8, np.uint16, np.uint32, np.uint64), - 'uint16': (np.uint16, np.uint32, np.uint64), - 'uint32': (np.uint32, np.uint64), - 'uint64': np.uint64, - 'numeric': (float, np.float32, np.float64, np.int8, np.int16, np.int32, np.int64, int, np.uint8, np.uint16, - np.uint32, np.uint64), - 'text': str, - 'utf': str, - 'utf8': str, - 'utf-8': str, - 'ascii': bytes, - 'bytes': bytes, - 'bool': (bool, np.bool_), - 'isodatetime': datetime, - 'datetime': datetime + "float32": (float, np.float32, np.float64), + "float": (float, np.float32, np.float64), + "float64": (float, np.float64), + "double": (float, np.float64), + "int8": (np.int8, np.int16, np.int32, np.int64, int), + "int16": (np.int16, np.int32, np.int64, int), + "short": (np.int16, np.int32, np.int64, int), + "int32": (int, np.int32, np.int64), + "int": (int, np.int32, np.int64), + "int64": np.int64, + "long": np.int64, + "uint8": (np.uint8, np.uint16, np.uint32, np.uint64), + "uint16": (np.uint16, np.uint32, np.uint64), + "uint32": (np.uint32, np.uint64), + "uint64": np.uint64, + "numeric": ( + float, + np.float32, + np.float64, + np.int8, + np.int16, + np.int32, + np.int64, + int, + np.uint8, + np.uint16, + np.uint32, + np.uint64, + ), + "text": str, + "utf": str, + "utf8": str, + "utf-8": str, + "ascii": bytes, + "bytes": bytes, + "bool": (bool, np.bool_), + "isodatetime": datetime, + "datetime": datetime, } @classmethod @@ -172,14 +195,14 @@ def _get_type(cls, spec, type_map): elif spec.shape is None and spec.dims is None: return cls._get_type_from_spec_dtype(spec.dtype) else: - return 'array_data', 'data' + return "array_data", "data" if isinstance(spec, LinkSpec): return cls._get_container_type(spec.target_type, type_map) if spec.data_type is not None: return cls._get_container_type(spec.data_type, type_map) if spec.shape is None and spec.dims is None: return cls._get_type_from_spec_dtype(spec.dtype) - return 'array_data', 'data' + return "array_data", "data" @classmethod def _ischild(cls, dtype): @@ -197,8 +220,8 @@ def _set_default_name(docval_args, default_name): """Set the default value for the name docval argument.""" if default_name is not None: for x in docval_args: - if x['name'] == 'name': - x['default'] = 
default_name + if x["name"] == "name": + x["default"] = default_name @classmethod def apply_generator_to_field(cls, field_spec, bases, type_map): @@ -206,7 +229,16 @@ def apply_generator_to_field(cls, field_spec, bases, type_map): return True @classmethod - def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_inherited_fields, type_map, spec): + def process_field_spec( + cls, + classdict, + docval_args, + parent_cls, + attr_name, + not_inherited_fields, + type_map, + spec, + ): """Add __fields__ to the classdict and update the docval args for the field spec with the given attribute name. :param classdict: The dict to update with __fields__ (or a different parent_cls._fieldsname). :param docval_args: The list of docval arguments. @@ -218,24 +250,19 @@ def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_i """ field_spec = not_inherited_fields[attr_name] dtype = cls._get_type(field_spec, type_map) - fields_conf = {'name': attr_name, - 'doc': field_spec['doc']} + fields_conf = {"name": attr_name, "doc": field_spec["doc"]} if cls._ischild(dtype) and issubclass(parent_cls, Container) and not isinstance(field_spec, LinkSpec): - fields_conf['child'] = True + fields_conf["child"] = True # if getattr(field_spec, 'value', None) is not None: # TODO set the fixed value on the class? # fields_conf['settable'] = False classdict.setdefault(parent_cls._fieldsname, list()).append(fields_conf) - docval_arg = dict( - name=attr_name, - doc=field_spec.doc, - type=cls._get_type(field_spec, type_map) - ) - shape = getattr(field_spec, 'shape', None) + docval_arg = dict(name=attr_name, doc=field_spec.doc, type=cls._get_type(field_spec, type_map)) + shape = getattr(field_spec, "shape", None) if shape is not None: - docval_arg['shape'] = shape + docval_arg["shape"] = shape if cls._check_spec_optional(field_spec, spec): - docval_arg['default'] = getattr(field_spec, 'default_value', None) + docval_arg["default"] = getattr(field_spec, "default_value", None) cls._add_to_docval_args(docval_args, docval_arg) @classmethod @@ -253,7 +280,7 @@ def _add_to_docval_args(cls, docval_args, arg, err_if_present=False): """Add the docval arg to the list if not present. If present, overwrite it in place or raise an error.""" inserted = False for i, x in enumerate(docval_args): - if x['name'] == arg['name']: + if x["name"] == arg["name"]: if err_if_present: raise ValueError("Argument %s already exists in docval args." 
% arg["name"]) docval_args[i] = arg @@ -279,7 +306,7 @@ def post_process(cls, classdict, bases, docval_args, spec): # be passed for a name positional or keyword arg if spec.name is not None: for arg in list(docval_args): - if arg['name'] == 'name': + if arg["name"] == "name": docval_args.remove(arg) # set default name in docval args if provided @@ -289,7 +316,7 @@ def post_process(cls, classdict, bases, docval_args, spec): def set_init(cls, classdict, bases, docval_args, not_inherited_fields, name): # get docval arg names from superclass base = bases[0] - parent_docval_args = set(arg['name'] for arg in get_docval(base.__init__)) + parent_docval_args = set(arg["name"] for arg in get_docval(base.__init__)) new_args = list() for attr_name, field_spec in not_inherited_fields.items(): # store arguments for fields that are not in the superclass and not in the superclass __init__ docval @@ -319,18 +346,26 @@ def __init__(self, **kwargs): for f, arg_val in new_kwargs.items(): setattr(self, f, arg_val) - classdict['__init__'] = __init__ + classdict["__init__"] = __init__ class MCIClassGenerator(CustomClassGenerator): - @classmethod def apply_generator_to_field(cls, field_spec, bases, type_map): """Return True if the field spec has quantity * or +, False otherwise.""" - return getattr(field_spec, 'quantity', None) in (ZERO_OR_MANY, ONE_OR_MANY) + return getattr(field_spec, "quantity", None) in (ZERO_OR_MANY, ONE_OR_MANY) @classmethod - def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_inherited_fields, type_map, spec): + def process_field_spec( + cls, + classdict, + docval_args, + parent_cls, + attr_name, + not_inherited_fields, + type_map, + spec, + ): """Add __clsconf__ to the classdict and update the docval args for the field spec with the given attribute name. :param classdict: The dict to update with __clsconf__. :param docval_args: The list of docval arguments. @@ -344,20 +379,20 @@ def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_i field_clsconf = dict( attr=attr_name, type=cls._get_type(field_spec, type_map), - add='add_{}'.format(attr_name), - get='get_{}'.format(attr_name), - create='create_{}'.format(attr_name) + add="add_{}".format(attr_name), + get="get_{}".format(attr_name), + create="create_{}".format(attr_name), ) - classdict.setdefault('__clsconf__', list()).append(field_clsconf) + classdict.setdefault("__clsconf__", list()).append(field_clsconf) # add a specialized docval arg for __init__ docval_arg = dict( name=attr_name, doc=field_spec.doc, - type=(list, tuple, dict, cls._get_type(field_spec, type_map)) + type=(list, tuple, dict, cls._get_type(field_spec, type_map)), ) if cls._check_spec_optional(field_spec, spec): - docval_arg['default'] = getattr(field_spec, 'default_value', None) + docval_arg["default"] = getattr(field_spec, "default_value", None) cls._add_to_docval_args(docval_args, docval_arg) @classmethod @@ -368,7 +403,7 @@ def post_process(cls, classdict, bases, docval_args, spec): :param docval_args: The dict of docval arguments. :param spec: The spec for the container class to generate. 
""" - if '__clsconf__' in classdict: + if "__clsconf__" in classdict: # do not add MCI as a base if a base is already a subclass of MultiContainerInterface for b in bases: if issubclass(b, MultiContainerInterface): @@ -386,24 +421,24 @@ def post_process(cls, classdict, bases, docval_args, spec): @classmethod def set_init(cls, classdict, bases, docval_args, not_inherited_fields, name): - if '__clsconf__' in classdict: - previous_init = classdict['__init__'] + if "__clsconf__" in classdict: + previous_init = classdict["__init__"] @docval(*docval_args, allow_positional=AllowPositional.WARNING) def __init__(self, **kwargs): # store the values passed to init for each MCI attribute so that they can be added # after calling __init__ new_kwargs = list() - for field_clsconf in classdict['__clsconf__']: - attr_name = field_clsconf['attr'] + for field_clsconf in classdict["__clsconf__"]: + attr_name = field_clsconf["attr"] # do not store the value if it is None or not present if attr_name not in kwargs or kwargs[attr_name] is None: continue - add_method_name = field_clsconf['add'] + add_method_name = field_clsconf["add"] new_kwarg = dict( attr_name=attr_name, value=popargs(attr_name, kwargs), - add_method_name=add_method_name + add_method_name=add_method_name, ) new_kwargs.append(new_kwarg) @@ -417,8 +452,8 @@ def __init__(self, **kwargs): # call the add method for each MCI attribute for new_kwarg in new_kwargs: - add_method = getattr(self, new_kwarg['add_method_name']) - add_method(new_kwarg['value']) + add_method = getattr(self, new_kwarg["add_method_name"]) + add_method(new_kwarg["value"]) # override __init__ - classdict['__init__'] = __init__ + classdict["__init__"] = __init__ diff --git a/src/hdmf/build/errors.py b/src/hdmf/build/errors.py index ec31ef5ba..278249241 100644 --- a/src/hdmf/build/errors.py +++ b/src/hdmf/build/errors.py @@ -7,41 +7,50 @@ class BuildError(Exception): """Error raised when building a container into a builder.""" - @docval({'name': 'builder', 'type': Builder, 'doc': 'the builder that cannot be built'}, - {'name': 'reason', 'type': str, 'doc': 'the reason for the error'}) + @docval( + {"name": "builder", "type": Builder, "doc": "the builder that cannot be built"}, + {"name": "reason", "type": str, "doc": "the reason for the error"}, + ) def __init__(self, **kwargs): - self.__builder = getargs('builder', kwargs) - self.__reason = getargs('reason', kwargs) + self.__builder = getargs("builder", kwargs) + self.__reason = getargs("reason", kwargs) self.__message = "%s (%s): %s" % (self.__builder.name, self.__builder.path, self.__reason) super().__init__(self.__message) class OrphanContainerBuildError(BuildError): - - @docval({'name': 'builder', 'type': Builder, 'doc': 'the builder containing the broken link'}, - {'name': 'container', 'type': AbstractContainer, 'doc': 'the container that has no parent'}) + @docval( + {"name": "builder", "type": Builder, "doc": "the builder containing the broken link"}, + {"name": "container", "type": AbstractContainer, "doc": "the container that has no parent"}, + ) def __init__(self, **kwargs): - builder = getargs('builder', kwargs) - self.__container = getargs('container', kwargs) - reason = ("Linked %s '%s' has no parent. Remove the link or ensure the linked container is added properly." - % (self.__container.__class__.__name__, self.__container.name)) + builder = getargs("builder", kwargs) + self.__container = getargs("container", kwargs) + reason = "Linked %s '%s' has no parent. 
Remove the link or ensure the linked container is added properly." % ( + self.__container.__class__.__name__, + self.__container.name, + ) super().__init__(builder=builder, reason=reason) class ReferenceTargetNotBuiltError(BuildError): - - @docval({'name': 'builder', 'type': Builder, 'doc': 'the builder containing the reference that cannot be found'}, - {'name': 'container', 'type': AbstractContainer, 'doc': 'the container that is not built yet'}) + @docval( + {"name": "builder", "type": Builder, "doc": "the builder containing the reference that cannot be found"}, + {"name": "container", "type": AbstractContainer, "doc": "the container that is not built yet"}, + ) def __init__(self, **kwargs): - builder = getargs('builder', kwargs) - self.__container = getargs('container', kwargs) - reason = ("Could not find already-built Builder for %s '%s' in BuildManager" - % (self.__container.__class__.__name__, self.__container.name)) + builder = getargs("builder", kwargs) + self.__container = getargs("container", kwargs) + reason = "Could not find already-built Builder for %s '%s' in BuildManager" % ( + self.__container.__class__.__name__, + self.__container.name, + ) super().__init__(builder=builder, reason=reason) class ContainerConfigurationError(Exception): """Error raised when the container class is improperly configured.""" + pass diff --git a/src/hdmf/build/manager.py b/src/hdmf/build/manager.py index 03f2856b8..7ff2245f3 100644 --- a/src/hdmf/build/manager.py +++ b/src/hdmf/build/manager.py @@ -50,14 +50,14 @@ def data_type(self): @docval({"name": "object", "type": (BaseBuilder, Container), "doc": "the container or builder to get a proxy for"}) def matches(self, **kwargs): - obj = getargs('object', kwargs) + obj = getargs("object", kwargs) if not isinstance(obj, Proxy): obj = self.__manager.get_proxy(obj) return self == obj @docval({"name": "container", "type": Container, "doc": "the Container to add as a candidate match"}) def add_candidate(self, **kwargs): - container = getargs('container', kwargs) + container = getargs("container", kwargs) self.__candidates.append(container) def resolve(self): @@ -67,14 +67,16 @@ def resolve(self): raise ValueError("No matching candidate Container found for " + self) def __eq__(self, other): - return self.data_type == other.data_type and \ - self.location == other.location and \ - self.namespace == other.namespace and \ - self.source == other.source + return ( + self.data_type == other.data_type + and self.location == other.location + and self.namespace == other.namespace + and self.source == other.source + ) def __repr__(self): ret = dict() - for key in ('source', 'location', 'namespace', 'data_type'): + for key in ("source", "location", "namespace", "data_type"): ret[key] = getattr(self, key, None) return str(ret) @@ -85,7 +87,7 @@ class BuildManager: """ def __init__(self, type_map): - self.logger = logging.getLogger('%s.%s' % (self.__class__.__module__, self.__class__.__qualname__)) + self.logger = logging.getLogger("%s.%s" % (self.__class__.__module__, self.__class__.__qualname__)) self.__builders = dict() self.__containers = dict() self.__active_builders = set() @@ -100,12 +102,21 @@ def namespace_catalog(self): def type_map(self): return self.__type_map - @docval({"name": "object", "type": (BaseBuilder, AbstractContainer), - "doc": "the container or builder to get a proxy for"}, - {"name": "source", "type": str, - "doc": "the source of container being built i.e. 
file path", 'default': None}) + @docval( + { + "name": "object", + "type": (BaseBuilder, AbstractContainer), + "doc": "the container or builder to get a proxy for", + }, + { + "name": "source", + "type": str, + "doc": "the source of container being built i.e. file path", + "default": None, + }, + ) def get_proxy(self, **kwargs): - obj = getargs('object', kwargs) + obj = getargs("object", kwargs) if isinstance(obj, BaseBuilder): return self._get_proxy_builder(obj) elif isinstance(obj, AbstractContainer): @@ -136,26 +147,56 @@ def _get_proxy_container(self, container): loc = "/".join(reversed(stack)) return Proxy(self, container.container_source, loc, ns, dt) - @docval({"name": "container", "type": AbstractContainer, "doc": "the container to convert to a Builder"}, - {"name": "source", "type": str, - "doc": "the source of container being built i.e. file path", 'default': None}, - {"name": "spec_ext", "type": BaseStorageSpec, "doc": "a spec that further refines the base specification", - 'default': None}, - {"name": "export", "type": bool, "doc": "whether this build is for exporting", - 'default': False}, - {"name": "root", "type": bool, "doc": "whether the container is the root of the build process", - 'default': False}) + @docval( + { + "name": "container", + "type": AbstractContainer, + "doc": "the container to convert to a Builder", + }, + { + "name": "source", + "type": str, + "doc": "the source of container being built i.e. file path", + "default": None, + }, + { + "name": "spec_ext", + "type": BaseStorageSpec, + "doc": "a spec that further refines the base specification", + "default": None, + }, + { + "name": "export", + "type": bool, + "doc": "whether this build is for exporting", + "default": False, + }, + { + "name": "root", + "type": bool, + "doc": "whether the container is the root of the build process", + "default": False, + }, + ) def build(self, **kwargs): - """ Build the GroupBuilder/DatasetBuilder for the given AbstractContainer""" - container, export = getargs('container', 'export', kwargs) - source, spec_ext, root = getargs('source', 'spec_ext', 'root', kwargs) + """Build the GroupBuilder/DatasetBuilder for the given AbstractContainer""" + container, export = getargs("container", "export", kwargs) + source, spec_ext, root = getargs("source", "spec_ext", "root", kwargs) result = self.get_builder(container) if root: self.__active_builders.clear() # reset active builders at start of build process if result is None: - self.logger.debug("Building new %s '%s' (container_source: %s, source: %s, extended spec: %s, export: %s)" - % (container.__class__.__name__, container.name, repr(container.container_source), - repr(source), spec_ext is not None, export)) + self.logger.debug( + "Building new %s '%s' (container_source: %s, source: %s, extended spec: %s, export: %s)" + % ( + container.__class__.__name__, + container.name, + repr(container.container_source), + repr(source), + spec_ext is not None, + export, + ) + ) # the container_source is not set or checked when exporting if not export: if container.container_source is None: @@ -164,9 +205,10 @@ def build(self, **kwargs): source = container.container_source else: if container.container_source != source: - raise ValueError("Cannot change container_source once set: '%s' %s.%s" - % (container.name, container.__class__.__module__, - container.__class__.__name__)) + raise ValueError( + "Cannot change container_source once set: '%s' %s.%s" + % (container.name, container.__class__.__module__, container.__class__.__name__) + ) # NOTE: if 
exporting, then existing cached builder will be ignored and overridden with new build result result = self.__type_map.build(container, self, source=source, spec_ext=spec_ext, export=export) self.prebuilt(container, result) @@ -174,27 +216,44 @@ def build(self, **kwargs): self.logger.debug("Done building %s '%s'" % (container.__class__.__name__, container.name)) elif not self.__is_active_builder(result) and container.modified: # if builder was built on file read and is then modified (append mode), it needs to be rebuilt - self.logger.debug("Rebuilding modified %s '%s' (source: %s, extended spec: %s)" - % (container.__class__.__name__, container.name, - repr(source), spec_ext is not None)) - result = self.__type_map.build(container, self, builder=result, source=source, spec_ext=spec_ext, - export=export) + self.logger.debug( + "Rebuilding modified %s '%s' (source: %s, extended spec: %s)" + % (container.__class__.__name__, container.name, repr(source), spec_ext is not None) + ) + result = self.__type_map.build( + container, + self, + builder=result, + source=source, + spec_ext=spec_ext, + export=export, + ) self.logger.debug("Done rebuilding %s '%s'" % (container.__class__.__name__, container.name)) else: - self.logger.debug("Using prebuilt %s '%s' for %s '%s'" - % (result.__class__.__name__, result.name, - container.__class__.__name__, container.name)) + self.logger.debug( + "Using prebuilt %s '%s' for %s '%s'" + % (result.__class__.__name__, result.name, container.__class__.__name__, container.name) + ) if root: # create reference builders only after building all other builders self.__add_refs() self.__active_builders.clear() # reset active builders now that build process has completed return result - @docval({"name": "container", "type": AbstractContainer, "doc": "the AbstractContainer to save as prebuilt"}, - {'name': 'builder', 'type': (DatasetBuilder, GroupBuilder), - 'doc': 'the Builder representation of the given container'}) + @docval( + { + "name": "container", + "type": AbstractContainer, + "doc": "the AbstractContainer to save as prebuilt", + }, + { + "name": "builder", + "type": (DatasetBuilder, GroupBuilder), + "doc": "the Builder representation of the given container", + }, + ) def prebuilt(self, **kwargs): - ''' Save the Builder for a given AbstractContainer for future use ''' - container, builder = getargs('container', 'builder', kwargs) + """Save the Builder for a given AbstractContainer for future use""" + container, builder = getargs("container", "builder", kwargs) container_id = self.__conthash__(container) self.__builders[container_id] = builder builder_id = self.__bldrhash__(builder) @@ -217,7 +276,7 @@ def __bldrhash__(self, obj): return id(obj) def __add_refs(self): - ''' + """ Add ReferenceBuilders. References get queued to be added after all other objects are built. This is because @@ -225,15 +284,16 @@ def __add_refs(self): does not happen in a guaranteed order. We need to build the targets of the reference builders so that the targets have the proper parent, and then write the reference builders after we write everything else. 
- ''' + """ while len(self.__ref_queue) > 0: call = self.__ref_queue.popleft() - self.logger.debug("Adding ReferenceBuilder with call id %d from queue (length %d)" - % (id(call), len(self.__ref_queue))) + self.logger.debug( + "Adding ReferenceBuilder with call id %d from queue (length %d)" % (id(call), len(self.__ref_queue)) + ) call() def queue_ref(self, func): - '''Set aside creating ReferenceBuilders''' + """Set aside creating ReferenceBuilders""" # TODO: come up with more intelligent way of # queueing reference resolution, based on reference # dependency @@ -246,9 +306,10 @@ def purge_outdated(self): container_id = self.__conthash__(container) builder = self.__builders.get(container_id) builder_id = self.__bldrhash__(builder) - self.logger.debug("Purging %s '%s' for %s '%s' from prebuilt cache" - % (builder.__class__.__name__, builder.name, - container.__class__.__name__, container.name)) + self.logger.debug( + "Purging %s '%s' for %s '%s' from prebuilt cache" + % (builder.__class__.__name__, builder.name, container.__class__.__name__, container.name) + ) self.__builders.pop(container_id) self.__containers.pop(builder_id) @@ -259,16 +320,21 @@ def clear_cache(self): @docval({"name": "container", "type": AbstractContainer, "doc": "the container to get the builder for"}) def get_builder(self, **kwargs): """Return the prebuilt builder for the given container or None if it does not exist.""" - container = getargs('container', kwargs) + container = getargs("container", kwargs) container_id = self.__conthash__(container) result = self.__builders.get(container_id) return result - @docval({'name': 'builder', 'type': (DatasetBuilder, GroupBuilder), - 'doc': 'the builder to construct the AbstractContainer from'}) + @docval( + { + "name": "builder", + "type": (DatasetBuilder, GroupBuilder), + "doc": "the builder to construct the AbstractContainer from", + } + ) def construct(self, **kwargs): - """ Construct the AbstractContainer represented by the given builder """ - builder = getargs('builder', kwargs) + """Construct the AbstractContainer represented by the given builder""" + builder = getargs("builder", kwargs) if isinstance(builder, LinkBuilder): builder = builder.target builder_id = self.__bldrhash__(builder) @@ -297,10 +363,10 @@ def __resolve_parents(self, container): stack.append(child) def __get_parent_dt_builder(self, builder): - ''' + """ Get the next builder above the given builder that has a data_type - ''' + """ tmp = builder.parent ret = None while tmp is not None: @@ -313,54 +379,84 @@ def __get_parent_dt_builder(self, builder): # *** The following methods just delegate calls to self.__type_map *** - @docval({'name': 'builder', 'type': Builder, 'doc': 'the Builder to get the class object for'}) + @docval({"name": "builder", "type": Builder, "doc": "the Builder to get the class object for"}) def get_cls(self, **kwargs): - ''' Get the class object for the given Builder ''' - builder = getargs('builder', kwargs) + """Get the class object for the given Builder""" + builder = getargs("builder", kwargs) return self.__type_map.get_cls(builder) - @docval({"name": "container", "type": AbstractContainer, "doc": "the container to convert to a Builder"}, - returns='The name a Builder should be given when building this container', rtype=str) + @docval( + {"name": "container", "type": AbstractContainer, "doc": "the container to convert to a Builder"}, + returns="The name a Builder should be given when building this container", + rtype=str, + ) def get_builder_name(self, **kwargs): - ''' Get the 
name a Builder should be given ''' - container = getargs('container', kwargs) + """Get the name a Builder should be given""" + container = getargs("container", kwargs) return self.__type_map.get_builder_name(container) - @docval({'name': 'spec', 'type': (DatasetSpec, GroupSpec), 'doc': 'the parent spec to search'}, - {'name': 'builder', 'type': (DatasetBuilder, GroupBuilder, LinkBuilder), - 'doc': 'the builder to get the sub-specification for'}) + @docval( + { + "name": "spec", + "type": (DatasetSpec, GroupSpec), + "doc": "the parent spec to search", + }, + { + "name": "builder", + "type": (DatasetBuilder, GroupBuilder, LinkBuilder), + "doc": "the builder to get the sub-specification for", + }, + ) def get_subspec(self, **kwargs): - ''' Get the specification from this spec that corresponds to the given builder ''' - spec, builder = getargs('spec', 'builder', kwargs) + """Get the specification from this spec that corresponds to the given builder""" + spec, builder = getargs("spec", "builder", kwargs) return self.__type_map.get_subspec(spec, builder) - @docval({'name': 'builder', 'type': (DatasetBuilder, GroupBuilder, LinkBuilder), - 'doc': 'the builder to get the sub-specification for'}) + @docval( + { + "name": "builder", + "type": (DatasetBuilder, GroupBuilder, LinkBuilder), + "doc": "the builder to get the sub-specification for", + } + ) def get_builder_ns(self, **kwargs): - ''' Get the namespace of a builder ''' - builder = getargs('builder', kwargs) + """Get the namespace of a builder""" + builder = getargs("builder", kwargs) return self.__type_map.get_builder_ns(builder) - @docval({'name': 'builder', 'type': (DatasetBuilder, GroupBuilder, LinkBuilder), - 'doc': 'the builder to get the data_type for'}) + @docval( + { + "name": "builder", + "type": (DatasetBuilder, GroupBuilder, LinkBuilder), + "doc": "the builder to get the data_type for", + } + ) def get_builder_dt(self, **kwargs): - ''' + """ Get the data_type of a builder - ''' - builder = getargs('builder', kwargs) + """ + builder = getargs("builder", kwargs) return self.__type_map.get_builder_dt(builder) - @docval({'name': 'builder', 'type': (GroupBuilder, DatasetBuilder, AbstractContainer), - 'doc': 'the builder or container to check'}, - {'name': 'parent_data_type', 'type': str, - 'doc': 'the potential parent data_type that refers to a data_type'}, - returns="True if data_type of *builder* is a sub-data_type of *parent_data_type*, False otherwise", - rtype=bool) + @docval( + { + "name": "builder", + "type": (GroupBuilder, DatasetBuilder, AbstractContainer), + "doc": "the builder or container to check", + }, + { + "name": "parent_data_type", + "type": str, + "doc": "the potential parent data_type that refers to a data_type", + }, + returns="True if data_type of *builder* is a sub-data_type of *parent_data_type*, False otherwise", + rtype=bool, + ) def is_sub_data_type(self, **kwargs): - ''' + """ Return whether or not data_type of *builder* is a sub-data_type of *parent_data_type* - ''' - builder, parent_dt = getargs('builder', 'parent_data_type', kwargs) + """ + builder, parent_dt = getargs("builder", "parent_data_type", kwargs) if isinstance(builder, (GroupBuilder, DatasetBuilder)): ns = self.get_builder_ns(builder) dt = self.get_builder_dt(builder) @@ -370,14 +466,24 @@ def is_sub_data_type(self, **kwargs): class TypeSource: - '''A class to indicate the source of a data_type in a namespace. + """A class to indicate the source of a data_type in a namespace. 
This class should only be used by TypeMap - ''' + """ - @docval({"name": "namespace", "type": str, "doc": "the namespace the from, which the data_type originated"}, - {"name": "data_type", "type": str, "doc": "the name of the type"}) + @docval( + { + "name": "namespace", + "type": str, + "doc": "the namespace the from, which the data_type originated", + }, + { + "name": "data_type", + "type": str, + "doc": "the name of the type", + }, + ) def __init__(self, **kwargs): - namespace, data_type = getargs('namespace', 'data_type', kwargs) + namespace, data_type = getargs("namespace", "data_type", kwargs) self.__namespace = namespace self.__data_type = data_type @@ -391,17 +497,19 @@ def data_type(self): class TypeMap: - ''' A class to maintain the map between ObjectMappers and AbstractContainer classes - ''' + """A class to maintain the map between ObjectMappers and AbstractContainer classes""" - @docval({'name': 'namespaces', 'type': NamespaceCatalog, 'doc': 'the NamespaceCatalog to use', 'default': None}, - {'name': 'mapper_cls', 'type': type, 'doc': 'the ObjectMapper class to use', 'default': None}) + @docval( + {"name": "namespaces", "type": NamespaceCatalog, "doc": "the NamespaceCatalog to use", "default": None}, + {"name": "mapper_cls", "type": type, "doc": "the ObjectMapper class to use", "default": None}, + ) def __init__(self, **kwargs): - namespaces, mapper_cls = getargs('namespaces', 'mapper_cls', kwargs) + namespaces, mapper_cls = getargs("namespaces", "mapper_cls", kwargs) if namespaces is None: namespaces = NamespaceCatalog() if mapper_cls is None: from .objectmapper import ObjectMapper # avoid circular import + mapper_cls = ObjectMapper self.__ns_catalog = namespaces self.__mappers = dict() # already constructed ObjectMapper classes @@ -459,17 +567,20 @@ def merge(self, type_map, ns_catalog=False): @docval({"name": "generator", "type": type, "doc": "the CustomClassGenerator class to register"}) def register_generator(self, **kwargs): """Add a custom class generator.""" - generator = getargs('generator', kwargs) + generator = getargs("generator", kwargs) self.__class_generator.register_generator(generator) - @docval(*get_docval(NamespaceCatalog.load_namespaces), - returns="the namespaces loaded from the given file", rtype=dict) + @docval( + *get_docval(NamespaceCatalog.load_namespaces), + returns="the namespaces loaded from the given file", + rtype=dict, + ) def load_namespaces(self, **kwargs): - '''Load namespaces from a namespace file. + """Load namespaces from a namespace file. This method will call load_namespaces on the NamespaceCatalog used to construct this TypeMap. Additionally, it will process the return value to keep track of what types were included in the loaded namespaces. Calling load_namespaces here has the advantage of being able to keep track of type dependencies across namespaces. 
- ''' + """ deps = self.__ns_catalog.load_namespaces(**kwargs) for new_ns, ns_deps in deps.items(): for src_ns, types in ns_deps.items(): @@ -480,23 +591,56 @@ def load_namespaces(self, **kwargs): self.register_container_type(new_ns, dt, container_cls) return deps - @docval({"name": "namespace", "type": str, "doc": "the namespace containing the data_type"}, - {"name": "data_type", "type": str, "doc": "the data type to create a AbstractContainer class for"}, - {"name": "autogen", "type": bool, "doc": "autogenerate class if one does not exist", "default": True}, - returns='the class for the given namespace and data_type', rtype=type) + @docval( + { + "name": "namespace", + "type": str, + "doc": "the namespace containing the data_type", + }, + { + "name": "data_type", + "type": str, + "doc": "the data type to create a AbstractContainer class for", + }, + { + "name": "autogen", + "type": bool, + "doc": "autogenerate class if one does not exist", + "default": True, + }, + returns="the class for the given namespace and data_type", + rtype=type, + ) def get_container_cls(self, **kwargs): """Get the container class from data type specification. If no class has been associated with the ``data_type`` from ``namespace``, a class will be dynamically created and returned. """ # NOTE: this internally used function get_container_cls will be removed in favor of get_dt_container_cls - namespace, data_type, autogen = getargs('namespace', 'data_type', 'autogen', kwargs) + namespace, data_type, autogen = getargs("namespace", "data_type", "autogen", kwargs) return self.get_dt_container_cls(data_type, namespace, autogen) - @docval({"name": "data_type", "type": str, "doc": "the data type to create a AbstractContainer class for"}, - {"name": "namespace", "type": str, "doc": "the namespace containing the data_type", "default": None}, - {"name": "autogen", "type": bool, "doc": "autogenerate class if one does not exist", "default": True}, - returns='the class for the given namespace and data_type', rtype=type) + @docval( + { + "name": "data_type", + "type": str, + "doc": "the data type to create a AbstractContainer class for", + }, + { + "name": "namespace", + "type": str, + "doc": "the namespace containing the data_type", + "default": None, + }, + { + "name": "autogen", + "type": bool, + "doc": "autogenerate class if one does not exist", + "default": True, + }, + returns="the class for the given namespace and data_type", + rtype=type, + ) def get_dt_container_cls(self, **kwargs): """Get the container class from data type specification. If no class has been associated with the ``data_type`` from ``namespace``, a class will be dynamically @@ -505,7 +649,7 @@ def get_dt_container_cls(self, **kwargs): Replaces get_container_cls but namespace is optional. If namespace is unknown, it will be looked up from all namespaces. """ - namespace, data_type, autogen = getargs('namespace', 'data_type', 'autogen', kwargs) + namespace, data_type, autogen = getargs("namespace", "data_type", "autogen", kwargs) # namespace is unknown, so look it up if namespace is None: @@ -529,8 +673,8 @@ def get_dt_container_cls(self, **kwargs): return cls def __check_dependent_types(self, spec, namespace): - """Ensure that classes for all types used by this type exist in this namespace and generate them if not. 
- """ + """Ensure that classes for all types used by this type exist in this namespace and generate them if not.""" + def __check_dependent_types_helper(spec, namespace): if isinstance(spec, (GroupSpec, DatasetSpec)): if spec.data_type_inc is not None: @@ -540,13 +684,13 @@ def __check_dependent_types_helper(spec, namespace): else: # spec is a LinkSpec self.get_dt_container_cls(spec.target_type, namespace) if isinstance(spec, GroupSpec): - for child_spec in (spec.groups + spec.datasets + spec.links): + for child_spec in spec.groups + spec.datasets + spec.links: __check_dependent_types_helper(child_spec, namespace) if spec.data_type_inc is not None: self.get_dt_container_cls(spec.data_type_inc, namespace) if isinstance(spec, GroupSpec): - for child_spec in (spec.groups + spec.datasets + spec.links): + for child_spec in spec.groups + spec.datasets + spec.links: __check_dependent_types_helper(child_spec, namespace) def __get_parent_cls(self, namespace, data_type, spec): @@ -582,8 +726,13 @@ def __get_container_cls(self, namespace, data_type): ret = cls return ret - @docval({'name': 'obj', 'type': (GroupBuilder, DatasetBuilder, LinkBuilder, GroupSpec, DatasetSpec), - 'doc': 'the object to get the type key for'}) + @docval( + { + "name": "obj", + "type": (GroupBuilder, DatasetBuilder, LinkBuilder, GroupSpec, DatasetSpec), + "doc": "the object to get the type key for", + } + ) def __type_key(self, obj): """ A wrapper function to simplify the process of getting a type_key for an object. @@ -596,13 +745,18 @@ def __type_key(self, obj): else: return self.__ns_catalog.dataset_spec_cls.type_key() - @docval({'name': 'builder', 'type': (DatasetBuilder, GroupBuilder, LinkBuilder), - 'doc': 'the builder to get the data_type for'}) + @docval( + { + "name": "builder", + "type": (DatasetBuilder, GroupBuilder, LinkBuilder), + "doc": "the builder to get the data_type for", + } + ) def get_builder_dt(self, **kwargs): - ''' + """ Get the data_type of a builder - ''' - builder = getargs('builder', kwargs) + """ + builder = getargs("builder", kwargs) ret = None if isinstance(builder, LinkBuilder): builder = builder.builder @@ -611,24 +765,34 @@ def get_builder_dt(self, **kwargs): else: ret = builder.attributes.get(self.__ns_catalog.dataset_spec_cls.type_key()) if isinstance(ret, bytes): - ret = ret.decode('UTF-8') + ret = ret.decode("UTF-8") return ret - @docval({'name': 'builder', 'type': (DatasetBuilder, GroupBuilder, LinkBuilder), - 'doc': 'the builder to get the sub-specification for'}) + @docval( + { + "name": "builder", + "type": (DatasetBuilder, GroupBuilder, LinkBuilder), + "doc": "the builder to get the sub-specification for", + } + ) def get_builder_ns(self, **kwargs): - ''' Get the namespace of a builder ''' - builder = getargs('builder', kwargs) + """Get the namespace of a builder""" + builder = getargs("builder", kwargs) if isinstance(builder, LinkBuilder): builder = builder.builder - ret = builder.attributes.get('namespace') + ret = builder.attributes.get("namespace") return ret - @docval({'name': 'builder', 'type': Builder, - 'doc': 'the Builder object to get the corresponding AbstractContainer class for'}) + @docval( + { + "name": "builder", + "type": Builder, + "doc": "the Builder object to get the corresponding AbstractContainer class for", + } + ) def get_cls(self, **kwargs): - ''' Get the class object for the given Builder ''' - builder = getargs('builder', kwargs) + """Get the class object for the given Builder""" + builder = getargs("builder", kwargs) data_type = self.get_builder_dt(builder) 
if data_type is None: raise ValueError("No data_type found for builder %s" % builder.path) @@ -637,12 +801,21 @@ def get_cls(self, **kwargs): raise ValueError("No namespace found for builder %s" % builder.path) return self.get_dt_container_cls(data_type, namespace) - @docval({'name': 'spec', 'type': (DatasetSpec, GroupSpec), 'doc': 'the parent spec to search'}, - {'name': 'builder', 'type': (DatasetBuilder, GroupBuilder, LinkBuilder), - 'doc': 'the builder to get the sub-specification for'}) + @docval( + { + "name": "spec", + "type": (DatasetSpec, GroupSpec), + "doc": "the parent spec to search", + }, + { + "name": "builder", + "type": (DatasetBuilder, GroupBuilder, LinkBuilder), + "doc": "the builder to get the sub-specification for", + }, + ) def get_subspec(self, **kwargs): - ''' Get the specification from this spec that corresponds to the given builder ''' - spec, builder = getargs('spec', 'builder', kwargs) + """Get the specification from this spec that corresponds to the given builder""" + spec, builder = getargs("spec", "builder", kwargs) if isinstance(builder, LinkBuilder): builder_type = type(builder.builder) # TODO consider checking against spec.get_link @@ -683,20 +856,33 @@ def get_container_cls_dt(self, cls): return ret return ret - @docval({'name': 'namespace', 'type': str, - 'doc': 'the namespace to get the container classes for', 'default': None}) + @docval( + { + "name": "namespace", + "type": str, + "doc": "the namespace to get the container classes for", + "default": None, + } + ) def get_container_classes(self, **kwargs): - namespace = getargs('namespace', kwargs) + namespace = getargs("namespace", kwargs) ret = self.__data_types.keys() if namespace is not None: ret = filter(lambda x: self.__data_types[x][0] == namespace, ret) return list(ret) - @docval({'name': 'obj', 'type': (AbstractContainer, Builder), 'doc': 'the object to get the ObjectMapper for'}, - returns='the ObjectMapper to use for mapping the given object', rtype='ObjectMapper') + @docval( + { + "name": "obj", + "type": (AbstractContainer, Builder), + "doc": "the object to get the ObjectMapper for", + }, + returns="the ObjectMapper to use for mapping the given object", + rtype="ObjectMapper", + ) def get_map(self, **kwargs): - """ Return the ObjectMapper object that should be used for the given container """ - obj = getargs('obj', kwargs) + """Return the ObjectMapper object that should be used for the given container""" + obj = getargs("obj", kwargs) # get the container class, and namespace/data_type if isinstance(obj, AbstractContainer): container_cls = obj.__class__ @@ -721,86 +907,164 @@ def get_map(self, **kwargs): self.__mappers[container_cls] = mapper return mapper - @docval({"name": "namespace", "type": str, "doc": "the namespace containing the data_type to map the class to"}, - {"name": "data_type", "type": str, "doc": "the data_type to map the class to"}, - {"name": "container_cls", "type": (TypeSource, type), "doc": "the class to map to the specified data_type"}) + @docval( + { + "name": "namespace", + "type": str, + "doc": "the namespace containing the data_type to map the class to", + }, + { + "name": "data_type", + "type": str, + "doc": "the data_type to map the class to", + }, + { + "name": "container_cls", + "type": (TypeSource, type), + "doc": "the class to map to the specified data_type", + }, + ) def register_container_type(self, **kwargs): - ''' Map a container class to a data_type ''' - namespace, data_type, container_cls = getargs('namespace', 'data_type', 'container_cls', kwargs) + 
"""Map a container class to a data_type""" + namespace, data_type, container_cls = getargs("namespace", "data_type", "container_cls", kwargs) spec = self.__ns_catalog.get_spec(namespace, data_type) # make sure the spec exists self.__container_types.setdefault(namespace, dict()) self.__container_types[namespace][data_type] = container_cls self.__data_types.setdefault(container_cls, (namespace, data_type)) if not isinstance(container_cls, TypeSource): setattr(container_cls, spec.type_key(), data_type) - setattr(container_cls, 'namespace', namespace) - - @docval({"name": "container_cls", "type": type, - "doc": "the AbstractContainer class for which the given ObjectMapper class gets used for"}, - {"name": "mapper_cls", "type": type, "doc": "the ObjectMapper class to use to map"}) + setattr(container_cls, "namespace", namespace) + + @docval( + { + "name": "container_cls", + "type": type, + "doc": "the AbstractContainer class for which the given ObjectMapper class gets used for", + }, + { + "name": "mapper_cls", + "type": type, + "doc": "the ObjectMapper class to use to map", + }, + ) def register_map(self, **kwargs): - ''' Map a container class to an ObjectMapper class ''' - container_cls, mapper_cls = getargs('container_cls', 'mapper_cls', kwargs) + """Map a container class to an ObjectMapper class""" + container_cls, mapper_cls = getargs("container_cls", "mapper_cls", kwargs) if self.get_container_cls_dt(container_cls) == (None, None): - raise ValueError('cannot register map for type %s - no data_type found' % container_cls) + raise ValueError("cannot register map for type %s - no data_type found" % container_cls) self.__mapper_cls[container_cls] = mapper_cls - @docval({"name": "container", "type": AbstractContainer, "doc": "the container to convert to a Builder"}, - {"name": "manager", "type": BuildManager, - "doc": "the BuildManager to use for managing this build", 'default': None}, - {"name": "source", "type": str, - "doc": "the source of container being built i.e. file path", 'default': None}, - {"name": "builder", "type": BaseBuilder, "doc": "the Builder to build on", 'default': None}, - {"name": "spec_ext", "type": BaseStorageSpec, "doc": "a spec extension", 'default': None}, - {"name": "export", "type": bool, "doc": "whether this build is for exporting", - 'default': False}) + @docval( + { + "name": "container", + "type": AbstractContainer, + "doc": "the container to convert to a Builder", + }, + { + "name": "manager", + "type": BuildManager, + "doc": "the BuildManager to use for managing this build", + "default": None, + }, + { + "name": "source", + "type": str, + "doc": "the source of container being built i.e. 
file path", + "default": None, + }, + { + "name": "builder", + "type": BaseBuilder, + "doc": "the Builder to build on", + "default": None, + }, + { + "name": "spec_ext", + "type": BaseStorageSpec, + "doc": "a spec extension", + "default": None, + }, + { + "name": "export", + "type": bool, + "doc": "whether this build is for exporting", + "default": False, + }, + ) def build(self, **kwargs): """Build the GroupBuilder/DatasetBuilder for the given AbstractContainer""" - container, manager, builder = getargs('container', 'manager', 'builder', kwargs) - source, spec_ext, export = getargs('source', 'spec_ext', 'export', kwargs) + container, manager, builder = getargs("container", "manager", "builder", kwargs) + source, spec_ext, export = getargs("source", "spec_ext", "export", kwargs) # get the ObjectMapper to map between Spec objects and AbstractContainer attributes obj_mapper = self.get_map(container) if obj_mapper is None: - raise ValueError('No ObjectMapper found for container of type %s' % str(container.__class__.__name__)) + raise ValueError("No ObjectMapper found for container of type %s" % str(container.__class__.__name__)) # convert the container to a builder using the ObjectMapper if manager is None: manager = BuildManager(self) - builder = obj_mapper.build(container, manager, builder=builder, source=source, spec_ext=spec_ext, export=export) + builder = obj_mapper.build( + container, + manager, + builder=builder, + source=source, + spec_ext=spec_ext, + export=export, + ) # add additional attributes (namespace, data_type, object_id) to builder namespace, data_type = self.get_container_ns_dt(container) - builder.set_attribute('namespace', namespace) + builder.set_attribute("namespace", namespace) builder.set_attribute(self.__type_key(obj_mapper.spec), data_type) builder.set_attribute(obj_mapper.spec.id_key(), container.object_id) return builder - @docval({'name': 'builder', 'type': (DatasetBuilder, GroupBuilder), - 'doc': 'the builder to construct the AbstractContainer from'}, - {'name': 'build_manager', 'type': BuildManager, - 'doc': 'the BuildManager for constructing', 'default': None}, - {'name': 'parent', 'type': (Proxy, Container), - 'doc': 'the parent Container/Proxy for the Container being built', 'default': None}) + @docval( + { + "name": "builder", + "type": (DatasetBuilder, GroupBuilder), + "doc": "the builder to construct the AbstractContainer from", + }, + { + "name": "build_manager", + "type": BuildManager, + "doc": "the BuildManager for constructing", + "default": None, + }, + { + "name": "parent", + "type": (Proxy, Container), + "doc": "the parent Container/Proxy for the Container being built", + "default": None, + }, + ) def construct(self, **kwargs): - """ Construct the AbstractContainer represented by the given builder """ - builder, build_manager, parent = getargs('builder', 'build_manager', 'parent', kwargs) + """Construct the AbstractContainer represented by the given builder""" + builder, build_manager, parent = getargs("builder", "build_manager", "parent", kwargs) if build_manager is None: build_manager = BuildManager(self) obj_mapper = self.get_map(builder) if obj_mapper is None: dt = builder.attributes[self.namespace_catalog.group_spec_cls.type_key()] - raise ValueError('No ObjectMapper found for builder of type %s' % dt) + raise ValueError("No ObjectMapper found for builder of type %s" % dt) else: return obj_mapper.construct(builder, build_manager, parent) - @docval({"name": "container", "type": AbstractContainer, "doc": "the container to convert to a Builder"}, 
- returns='The name a Builder should be given when building this container', rtype=str) + @docval( + { + "name": "container", + "type": AbstractContainer, + "doc": "the container to convert to a Builder", + }, + returns="The name a Builder should be given when building this container", + rtype=str, + ) def get_builder_name(self, **kwargs): - ''' Get the name a Builder should be given ''' - container = getargs('container', kwargs) + """Get the name a Builder should be given""" + container = getargs("container", kwargs) obj_mapper = self.get_map(container) if obj_mapper is None: - raise ValueError('No ObjectMapper found for container of type %s' % str(container.__class__.__name__)) + raise ValueError("No ObjectMapper found for container of type %s" % str(container.__class__.__name__)) else: return obj_mapper.get_builder_name(container) diff --git a/src/hdmf/build/map.py b/src/hdmf/build/map.py index 92b0c7499..53c7aff2b 100644 --- a/src/hdmf/build/map.py +++ b/src/hdmf/build/map.py @@ -3,5 +3,8 @@ from .objectmapper import ObjectMapper # noqa: F401 import warnings -warnings.warn('Classes in map.py should be imported from hdmf.build. Importing from hdmf.build.map will be removed ' - 'in HDMF 3.0.', DeprecationWarning) + +warnings.warn( + "Classes in map.py should be imported from hdmf.build. Importing from hdmf.build.map will be removed in HDMF 3.0.", + DeprecationWarning, +) diff --git a/src/hdmf/build/objectmapper.py b/src/hdmf/build/objectmapper.py index 9786981c5..bb387cc60 100644 --- a/src/hdmf/build/objectmapper.py +++ b/src/hdmf/build/objectmapper.py @@ -7,11 +7,28 @@ import numpy as np -from .builders import DatasetBuilder, GroupBuilder, LinkBuilder, Builder, ReferenceBuilder, RegionBuilder, BaseBuilder -from .errors import (BuildError, OrphanContainerBuildError, ReferenceTargetNotBuiltError, ContainerConfigurationError, - ConstructError) +from .builders import ( + DatasetBuilder, + GroupBuilder, + LinkBuilder, + Builder, + ReferenceBuilder, + RegionBuilder, + BaseBuilder, +) +from .errors import ( + BuildError, + OrphanContainerBuildError, + ReferenceTargetNotBuiltError, + ContainerConfigurationError, + ConstructError, +) from .manager import Proxy, BuildManager -from .warnings import MissingRequiredBuildWarning, DtypeConversionWarning, IncorrectQuantityBuildWarning +from .warnings import ( + MissingRequiredBuildWarning, + DtypeConversionWarning, + IncorrectQuantityBuildWarning, +) from ..container import AbstractContainer, Data, DataRegion from ..data_utils import DataIO, AbstractDataChunkIterator from ..query import ReferenceResolver @@ -19,20 +36,22 @@ from ..spec.spec import BaseStorageSpec from ..utils import docval, getargs, ExtenderMeta, get_docval -_const_arg = '__constructor_arg' +_const_arg = "__constructor_arg" -@docval({'name': 'name', 'type': str, 'doc': 'the name of the constructor argument'}, - is_method=False) +@docval( + {"name": "name", "type": str, "doc": "the name of the constructor argument"}, + is_method=False, +) def _constructor_arg(**kwargs): - '''Decorator to override the default mapping scheme for a given constructor argument. + """Decorator to override the default mapping scheme for a given constructor argument. Decorate ObjectMapper methods with this function when extending ObjectMapper to override the default scheme for mapping between AbstractContainer and Builder objects. The decorated method should accept as its first argument the Builder object that is being mapped. 
The method should return the value to be passed to the target AbstractContainer class constructor argument given by *name*. - ''' - name = getargs('name', kwargs) + """ + name = getargs("name", kwargs) def _dec(func): setattr(func, _const_arg, name) @@ -41,21 +60,23 @@ def _dec(func): return _dec -_obj_attr = '__object_attr' +_obj_attr = "__object_attr" -@docval({'name': 'name', 'type': str, 'doc': 'the name of the constructor argument'}, - is_method=False) +@docval( + {"name": "name", "type": str, "doc": "the name of the constructor argument"}, + is_method=False, +) def _object_attr(**kwargs): - '''Decorator to override the default mapping scheme for a given object attribute. + """Decorator to override the default mapping scheme for a given object attribute. Decorate ObjectMapper methods with this function when extending ObjectMapper to override the default scheme for mapping between AbstractContainer and Builder objects. The decorated method should accept as its first argument the AbstractContainer object that is being mapped. The method should return the child Builder object (or scalar if the object attribute corresponds to an AttributeSpec) that represents the attribute given by *name*. - ''' - name = getargs('name', kwargs) + """ + name = getargs("name", kwargs) def _dec(func): setattr(func, _obj_attr, name) @@ -71,7 +92,7 @@ def _unicode(s): if isinstance(s, str): return s elif isinstance(s, bytes): - return s.decode('utf-8') + return s.decode("utf-8") else: raise ValueError("Expected unicode or ascii string, got %s" % type(s)) @@ -81,7 +102,7 @@ def _ascii(s): A helper function for converting to ASCII """ if isinstance(s, str): - return s.encode('ascii', 'backslashreplace') + return s.encode("ascii", "backslashreplace") elif isinstance(s, bytes): return s else: @@ -89,9 +110,7 @@ def _ascii(s): class ObjectMapper(metaclass=ExtenderMeta): - '''A class for mapping between Spec objects and AbstractContainer attributes - - ''' + """A class for mapping between Spec objects and AbstractContainer attributes""" # mapping from spec dtypes to numpy dtypes or functions for conversion of values to spec dtypes # make sure keys are consistent between hdmf.spec.spec.DtypeHelper.primary_dtype_synonyms, @@ -143,35 +162,43 @@ def __resolve_numeric_dtype(cls, given, specified): if g.itemsize <= s.itemsize: # given type has precision < precision of specified type # note: this allows float32 -> int32, bool -> int8, int16 -> uint16 which may involve buffer overflows, # truncated values, and other unexpected consequences. - warning_msg = ('Value with data type %s is being converted to data type %s as specified.' - % (g.name, s.name)) + warning_msg = "Value with data type %s is being converted to data type %s as specified." % (g.name, s.name) return s.type, warning_msg elif g.name[:3] == s.name[:3]: return g.type, None # same base type, use higher-precision given type else: if np.issubdtype(s, np.unsignedinteger): # e.g.: given int64 and spec uint32, return uint64. given float32 and spec uint8, return uint32. - ret_type = np.dtype('uint' + str(int(g.itemsize * 8))) - warning_msg = ('Value with data type %s is being converted to data type %s (min specification: %s).' - % (g.name, ret_type.name, s.name)) + ret_type = np.dtype("uint" + str(int(g.itemsize * 8))) + warning_msg = "Value with data type %s is being converted to data type %s (min specification: %s)." 
% ( + g.name, + ret_type.name, + s.name, + ) return ret_type.type, warning_msg if np.issubdtype(s, np.floating): # e.g.: given int64 and spec float32, return float64. given uint64 and spec float32, return float32. - ret_type = np.dtype('float' + str(max(int(g.itemsize * 8), 32))) - warning_msg = ('Value with data type %s is being converted to data type %s (min specification: %s).' - % (g.name, ret_type.name, s.name)) + ret_type = np.dtype("float" + str(max(int(g.itemsize * 8), 32))) + warning_msg = "Value with data type %s is being converted to data type %s (min specification: %s)." % ( + g.name, + ret_type.name, + s.name, + ) return ret_type.type, warning_msg if np.issubdtype(s, np.integer): # e.g.: given float64 and spec int8, return int64. given uint32 and spec int8, return int32. - ret_type = np.dtype('int' + str(int(g.itemsize * 8))) - warning_msg = ('Value with data type %s is being converted to data type %s (min specification: %s).' - % (g.name, ret_type.name, s.name)) + ret_type = np.dtype("int" + str(int(g.itemsize * 8))) + warning_msg = "Value with data type %s is being converted to data type %s (min specification: %s)." % ( + g.name, + ret_type.name, + s.name, + ) return ret_type.type, warning_msg if s.type is np.bool_: msg = "expected %s, received %s - must supply %s" % (s.name, g.name, s.name) raise ValueError(msg) # all numeric types in __dtypes should be caught by the above - raise ValueError('Unsupported conversion to specification data type: %s' % s.name) + raise ValueError("Unsupported conversion to specification data type: %s" % s.name) @classmethod def no_convert(cls, obj_type): @@ -205,13 +232,12 @@ def convert_dtype(cls, spec, value, spec_dtype=None): # noqa: C901 spec_dtype_type = cls.__dtypes[spec_dtype] warning_msg = None # Numpy Array or Zarr array - if (isinstance(value, np.ndarray) or - (hasattr(value, 'astype') and hasattr(value, 'dtype'))): + if isinstance(value, np.ndarray) or (hasattr(value, "astype") and hasattr(value, "dtype")): if spec_dtype_type is _unicode: - ret = value.astype('U') + ret = value.astype("U") ret_dtype = "utf8" elif spec_dtype_type is _ascii: - ret = value.astype('S') + ret = value.astype("S") ret_dtype = "ascii" else: dtype_func, warning_msg = cls.__resolve_numeric_dtype(value.dtype, spec_dtype_type) @@ -224,9 +250,9 @@ def convert_dtype(cls, spec, value, spec_dtype=None): # noqa: C901 elif isinstance(value, (tuple, list)): if len(value) == 0: if spec_dtype_type is _unicode: - ret_dtype = 'utf8' + ret_dtype = "utf8" elif spec_dtype_type is _ascii: - ret_dtype = 'ascii' + ret_dtype = "ascii" else: ret_dtype = spec_dtype_type return value, ret_dtype @@ -247,9 +273,9 @@ def convert_dtype(cls, spec, value, spec_dtype=None): # noqa: C901 ret_dtype, warning_msg = cls.__resolve_numeric_dtype(value.dtype, spec_dtype_type) else: if spec_dtype_type in (_unicode, _ascii): - ret_dtype = 'ascii' + ret_dtype = "ascii" if spec_dtype_type is _unicode: - ret_dtype = 'utf8' + ret_dtype = "utf8" ret = spec_dtype_type(value) else: dtype_func, warning_msg = cls.__resolve_numeric_dtype(type(value), spec_dtype_type) @@ -264,9 +290,11 @@ def convert_dtype(cls, spec, value, spec_dtype=None): # noqa: C901 def __check_convert_numeric(cls, value_type): # dtype 'numeric' allows only ints, floats, and uints value_dtype = np.dtype(value_type) - if not (np.issubdtype(value_dtype, np.unsignedinteger) or - np.issubdtype(value_dtype, np.floating) or - np.issubdtype(value_dtype, np.integer)): + if not ( + np.issubdtype(value_dtype, np.unsignedinteger) + or 
np.issubdtype(value_dtype, np.floating) + or np.issubdtype(value_dtype, np.integer) + ): raise ValueError("Cannot convert from %s to 'numeric' specification dtype." % value_type) @classmethod # noqa: C901 @@ -292,19 +320,19 @@ def __check_edgecases(cls, spec, value, spec_dtype): # noqa: C901 else: # Determine the dtype from the DataIO.data return value, cls.convert_dtype(spec, value.data, spec_dtype)[1] - if spec_dtype is None or spec_dtype == 'numeric' or type(value) in cls.__no_convert: + if spec_dtype is None or spec_dtype == "numeric" or type(value) in cls.__no_convert: # infer type from value - if hasattr(value, 'dtype'): # covers numpy types, Zarr Array, AbstractDataChunkIterator - if spec_dtype == 'numeric': + if hasattr(value, "dtype"): # covers numpy types, Zarr Array, AbstractDataChunkIterator + if spec_dtype == "numeric": cls.__check_convert_numeric(value.dtype.type) if np.issubdtype(value.dtype, np.str_): - ret_dtype = 'utf8' + ret_dtype = "utf8" elif np.issubdtype(value.dtype, np.string_): - ret_dtype = 'ascii' - elif np.issubdtype(value.dtype, np.dtype('O')): + ret_dtype = "ascii" + elif np.issubdtype(value.dtype, np.dtype("O")): # Only variable-length strings should ever appear as generic objects. # Everything else should have a well-defined type - ret_dtype = 'utf8' + ret_dtype = "utf8" else: ret_dtype = value.dtype.type return value, ret_dtype @@ -312,14 +340,17 @@ def __check_edgecases(cls, spec, value, spec_dtype): # noqa: C901 if len(value) == 0: msg = "Cannot infer dtype of empty list or tuple. Please use numpy array with specified dtype." raise ValueError(msg) - return value, cls.__check_edgecases(spec, value[0], spec_dtype)[1] # infer dtype from first element + return ( + value, + cls.__check_edgecases(spec, value[0], spec_dtype)[1], + ) # infer dtype from first element ret_dtype = type(value) - if spec_dtype == 'numeric': + if spec_dtype == "numeric": cls.__check_convert_numeric(ret_dtype) if ret_dtype is str: - ret_dtype = 'utf8' + ret_dtype = "utf8" elif ret_dtype is bytes: - ret_dtype = 'ascii' + ret_dtype = "ascii" return value, ret_dtype if isinstance(spec_dtype, RefSpec): if not isinstance(value, ReferenceBuilder): @@ -331,37 +362,41 @@ def __check_edgecases(cls, spec, value, spec_dtype): # noqa: C901 raise ValueError(msg) return None, None - _const_arg = '__constructor_arg' + _const_arg = "__constructor_arg" @staticmethod - @docval({'name': 'name', 'type': str, 'doc': 'the name of the constructor argument'}, - is_method=False) + @docval( + {"name": "name", "type": str, "doc": "the name of the constructor argument"}, + is_method=False, + ) def constructor_arg(**kwargs): - '''Decorator to override the default mapping scheme for a given constructor argument. + """Decorator to override the default mapping scheme for a given constructor argument. Decorate ObjectMapper methods with this function when extending ObjectMapper to override the default scheme for mapping between AbstractContainer and Builder objects. The decorated method should accept as its first argument the Builder object that is being mapped. The method should return the value to be passed to the target AbstractContainer class constructor argument given by *name*. 
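The numeric-resolution rules above can be exercised directly through the public ``convert_dtype`` classmethod. A minimal sketch, assuming the core ``int32`` dtype spelling; the spec and the values are illustrative.

.. code:: python

    import numpy as np

    from hdmf.build import ObjectMapper
    from hdmf.spec import DatasetSpec

    spec = DatasetSpec(doc="example counts", dtype="int32", name="counts")
    value, dtype = ObjectMapper.convert_dtype(spec, np.arange(3, dtype=np.int64))
    # int64 data against an int32 spec keeps the higher-precision given type
    assert dtype is np.int64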
- ''' - name = getargs('name', kwargs) + """ + name = getargs("name", kwargs) return _constructor_arg(name) - _obj_attr = '__object_attr' + _obj_attr = "__object_attr" @staticmethod - @docval({'name': 'name', 'type': str, 'doc': 'the name of the constructor argument'}, - is_method=False) + @docval( + {"name": "name", "type": str, "doc": "the name of the constructor argument"}, + is_method=False, + ) def object_attr(**kwargs): - '''Decorator to override the default mapping scheme for a given object attribute. + """Decorator to override the default mapping scheme for a given object attribute. Decorate ObjectMapper methods with this function when extending ObjectMapper to override the default scheme for mapping between AbstractContainer and Builder objects. The decorated method should accept as its first argument the AbstractContainer object that is being mapped. The method should return the child Builder object (or scalar if the object attribute corresponds to an AttributeSpec) that represents the attribute given by *name*. - ''' - name = getargs('name', kwargs) + """ + name = getargs("name", kwargs) return _object_attr(name) @staticmethod @@ -382,11 +417,11 @@ def __get_cargname(attr_val): @ExtenderMeta.post_init def __gather_procedures(cls, name, bases, classdict): - if hasattr(cls, 'constructor_args'): + if hasattr(cls, "constructor_args"): cls.constructor_args = copy(cls.constructor_args) else: cls.constructor_args = dict() - if hasattr(cls, 'obj_attrs'): + if hasattr(cls, "obj_attrs"): cls.obj_attrs = copy(cls.obj_attrs) else: cls.obj_attrs = dict() @@ -396,12 +431,13 @@ def __gather_procedures(cls, name, bases, classdict): elif cls.__is_attr(func): cls.obj_attrs[cls.__get_obj_attr(func)] = getattr(cls, name) - @docval({'name': 'spec', 'type': (DatasetSpec, GroupSpec), - 'doc': 'The specification for mapping objects to builders'}) + @docval( + {"name": "spec", "type": (DatasetSpec, GroupSpec), "doc": "The specification for mapping objects to builders"} + ) def __init__(self, **kwargs): - """ Create a map from AbstractContainer attributes to specifications """ - self.logger = logging.getLogger('%s.%s' % (self.__class__.__module__, self.__class__.__qualname__)) - spec = getargs('spec', kwargs) + """Create a map from AbstractContainer attributes to specifications""" + self.logger = logging.getLogger("%s.%s" % (self.__class__.__module__, self.__class__.__qualname__)) + spec = getargs("spec", kwargs) self.__spec = spec self.__data_type_key = spec.type_key() self.__spec2attr = dict() @@ -412,24 +448,24 @@ def __init__(self, **kwargs): @property def spec(self): - ''' the Spec used in this ObjectMapper ''' + """the Spec used in this ObjectMapper""" return self.__spec - @_constructor_arg('name') + @_constructor_arg("name") def get_container_name(self, *args): builder = args[0] return builder.name @classmethod - @docval({'name': 'spec', 'type': Spec, 'doc': 'the specification to get the name for'}) + @docval({"name": "spec", "type": Spec, "doc": "the specification to get the name for"}) def convert_dt_name(cls, **kwargs): - '''Construct the attribute name corresponding to a specification''' - spec = getargs('spec', kwargs) + """Construct the attribute name corresponding to a specification""" + spec = getargs("spec", kwargs) name = cls.__get_data_type(spec) - s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name) - name = re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower() - if name[-1] != 's' and spec.is_many(): - name += 's' + s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) + name = 
re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() + if name[-1] != "s" and spec.is_many(): + name += "s" return name @classmethod @@ -438,7 +474,7 @@ def __get_fields(cls, name_stack, all_names, spec): if spec.name is None: name = cls.convert_dt_name(spec) name_stack.append(name) - name = '__'.join(name_stack) + name = "__".join(name_stack) # TODO address potential name clashes, e.g., quantity '*' subgroups and links of same data_type_inc will # have the same name all_names[name] = spec @@ -459,10 +495,10 @@ def __get_fields(cls, name_stack, all_names, spec): name_stack.pop() @classmethod - @docval({'name': 'spec', 'type': Spec, 'doc': 'the specification to get the object attribute names for'}) + @docval({"name": "spec", "type": Spec, "doc": "the specification to get the object attribute names for"}) def get_attr_names(cls, **kwargs): - '''Get the attribute names for each subspecification in a Spec''' - spec = getargs('spec', kwargs) + """Get the attribute names for each subspecification in a Spec""" + spec = getargs("spec", kwargs) names = OrderedDict() for subspec in spec.attributes: cls.__get_fields(list(), names, subspec) @@ -480,46 +516,52 @@ def __map_spec(self, spec): for k, v in attr_names.items(): self.map_spec(k, v) - @docval({"name": "attr_name", "type": str, "doc": "the name of the object to map"}, - {"name": "spec", "type": Spec, "doc": "the spec to map the attribute to"}) + @docval( + {"name": "attr_name", "type": str, "doc": "the name of the object to map"}, + {"name": "spec", "type": Spec, "doc": "the spec to map the attribute to"}, + ) def map_attr(self, **kwargs): - """ Map an attribute to spec. Use this to override default behavior """ - attr_name, spec = getargs('attr_name', 'spec', kwargs) + """Map an attribute to spec. Use this to override default behavior""" + attr_name, spec = getargs("attr_name", "spec", kwargs) self.__spec2attr[spec] = attr_name self.__attr2spec[attr_name] = spec @docval({"name": "attr_name", "type": str, "doc": "the name of the attribute"}) def get_attr_spec(self, **kwargs): - """ Return the Spec for a given attribute """ - attr_name = getargs('attr_name', kwargs) + """Return the Spec for a given attribute""" + attr_name = getargs("attr_name", kwargs) return self.__attr2spec.get(attr_name) @docval({"name": "carg_name", "type": str, "doc": "the name of the constructor argument"}) def get_carg_spec(self, **kwargs): - """ Return the Spec for a given constructor argument """ - carg_name = getargs('carg_name', kwargs) + """Return the Spec for a given constructor argument""" + carg_name = getargs("carg_name", kwargs) return self.__carg2spec.get(carg_name) - @docval({"name": "const_arg", "type": str, "doc": "the name of the constructor argument to map"}, - {"name": "spec", "type": Spec, "doc": "the spec to map the attribute to"}) + @docval( + {"name": "const_arg", "type": str, "doc": "the name of the constructor argument to map"}, + {"name": "spec", "type": Spec, "doc": "the spec to map the attribute to"}, + ) def map_const_arg(self, **kwargs): - """ Map an attribute to spec. Use this to override default behavior """ - const_arg, spec = getargs('const_arg', 'spec', kwargs) + """Map an attribute to spec. Use this to override default behavior""" + const_arg, spec = getargs("const_arg", "spec", kwargs) self.__spec2carg[spec] = const_arg self.__carg2spec[const_arg] = spec @docval({"name": "spec", "type": Spec, "doc": "the spec to map the attribute to"}) def unmap(self, **kwargs): - """ Removing any mapping for a specification. 
Use this to override default mapping """ - spec = getargs('spec', kwargs) + """Removing any mapping for a specification. Use this to override default mapping""" + spec = getargs("spec", kwargs) self.__spec2attr.pop(spec, None) self.__spec2carg.pop(spec, None) - @docval({"name": "attr_carg", "type": str, "doc": "the constructor argument/object attribute to map this spec to"}, - {"name": "spec", "type": Spec, "doc": "the spec to map the attribute to"}) + @docval( + {"name": "attr_carg", "type": str, "doc": "the constructor argument/object attribute to map this spec to"}, + {"name": "spec", "type": Spec, "doc": "the spec to map the attribute to"}, + ) def map_spec(self, **kwargs): - """ Map the given specification to the construct argument and object attribute """ - spec, attr_carg = getargs('spec', 'attr_carg', kwargs) + """Map the given specification to the construct argument and object attribute""" + spec, attr_carg = getargs("spec", "attr_carg", kwargs) self.map_const_arg(attr_carg, spec) self.map_attr(attr_carg, spec) @@ -539,21 +581,26 @@ def __get_override_attr(self, name, container, manager): return func(self, container, manager) return None - @docval({"name": "spec", "type": Spec, "doc": "the spec to get the attribute for"}, - returns='the attribute name', rtype=str) + @docval( + {"name": "spec", "type": Spec, "doc": "the spec to get the attribute for"}, + returns="the attribute name", + rtype=str, + ) def get_attribute(self, **kwargs): - ''' Get the object attribute name for the given Spec ''' - spec = getargs('spec', kwargs) + """Get the object attribute name for the given Spec""" + spec = getargs("spec", kwargs) val = self.__spec2attr.get(spec, None) return val - @docval({"name": "spec", "type": Spec, "doc": "the spec to get the attribute value for"}, - {"name": "container", "type": AbstractContainer, "doc": "the container to get the attribute value from"}, - {"name": "manager", "type": BuildManager, "doc": "the BuildManager used for managing this build"}, - returns='the value of the attribute') + @docval( + {"name": "spec", "type": Spec, "doc": "the spec to get the attribute value for"}, + {"name": "container", "type": AbstractContainer, "doc": "the container to get the attribute value from"}, + {"name": "manager", "type": BuildManager, "doc": "the BuildManager used for managing this build"}, + returns="the value of the attribute", + ) def get_attr_value(self, **kwargs): - ''' Get the value of the attribute corresponding to this spec from the given container ''' - spec, container, manager = getargs('spec', 'container', 'manager', kwargs) + """Get the value of the attribute corresponding to this spec from the given container""" + spec, container, manager = getargs("spec", "container", "manager", kwargs) attr_name = self.get_attribute(spec) if attr_name is None: return None @@ -562,8 +609,12 @@ def get_attr_value(self, **kwargs): try: attr_val = getattr(container, attr_name) except AttributeError: - msg = ("%s '%s' does not have attribute '%s' for mapping to spec: %s" - % (container.__class__.__name__, container.name, attr_name, spec)) + msg = "%s '%s' does not have attribute '%s' for mapping to spec: %s" % ( + container.__class__.__name__, + container.name, + attr_name, + spec, + ) raise ContainerConfigurationError(msg) if attr_val is not None: attr_val = self.__convert_string(attr_val, spec) @@ -572,8 +623,11 @@ def get_attr_value(self, **kwargs): try: attr_val = self.__filter_by_spec_dt(attr_val, spec_dt, manager) except ValueError as e: - msg = ("%s '%s' attribute '%s' has 
unexpected type." - % (container.__class__.__name__, container.name, attr_name)) + msg = "%s '%s' attribute '%s' has unexpected type." % ( + container.__class__.__name__, + container.name, + attr_name, + ) raise ContainerConfigurationError(msg) from e # else: attr_val is an attribute on the Container and its value is None # attr_val can be None, an AbstractContainer, or a list of AbstractContainers @@ -597,7 +651,7 @@ def __convert_string(self, value, spec): """Convert string types to the specified dtype.""" ret = value if isinstance(spec, AttributeSpec): - if 'text' in spec.dtype: + if "text" in spec.dtype: if spec.shape is not None or spec.dims is not None: ret = list(map(str, value)) else: @@ -606,11 +660,11 @@ def __convert_string(self, value, spec): # TODO: make sure we can handle specs with data_type_inc set if spec.data_type_inc is None and spec.dtype is not None: string_type = None - if 'text' in spec.dtype: + if "text" in spec.dtype: string_type = str - elif 'ascii' in spec.dtype: + elif "ascii" in spec.dtype: string_type = bytes - elif 'isodatetime' in spec.dtype: + elif "isodatetime" in spec.dtype: string_type = datetime.isoformat if string_type is not None: if spec.shape is not None or spec.dims is not None: @@ -620,7 +674,7 @@ def __convert_string(self, value, spec): # copy over any I/O parameters if they were specified if isinstance(value, DataIO): params = value.get_io_params() - params['data'] = ret + params["data"] = ret ret = value.__class__(**params) return ret @@ -656,58 +710,77 @@ def __filter_by_spec_dt(self, attr_value, spec_dt, build_manager): if len(ret) == 0: ret = None else: - raise ValueError("Unexpected type for attr_value: %s. Only AbstractContainer, list, tuple, set, dict, are " - "allowed." % type(attr_value)) + raise ValueError( + "Unexpected type for attr_value: %s. Only AbstractContainer, list, tuple, set, dict, are allowed." + % type(attr_value) + ) return ret def __check_quantity(self, attr_value, spec, container): if attr_value is None and spec.required: attr_name = self.get_attribute(spec) - msg = ("%s '%s' is missing required value for attribute '%s'." - % (container.__class__.__name__, container.name, attr_name)) + msg = "%s '%s' is missing required value for attribute '%s'." % ( + container.__class__.__name__, + container.name, + attr_name, + ) warnings.warn(msg, MissingRequiredBuildWarning) - self.logger.debug('MissingRequiredBuildWarning: ' + msg) + self.logger.debug("MissingRequiredBuildWarning: " + msg) elif attr_value is not None and self.__get_data_type(spec) is not None: # quantity is valid only for specs with a data type or target type if isinstance(attr_value, AbstractContainer): attr_value = [attr_value] n = len(attr_value) - if (n and isinstance(attr_value[0], AbstractContainer) and - ((n > 1 and not spec.is_many()) or (isinstance(spec.quantity, int) and n != spec.quantity))): + if ( + n + and isinstance(attr_value[0], AbstractContainer) + and ((n > 1 and not spec.is_many()) or (isinstance(spec.quantity, int) and n != spec.quantity)) + ): attr_name = self.get_attribute(spec) - msg = ("%s '%s' has %d values for attribute '%s' but spec allows %s." - % (container.__class__.__name__, container.name, n, attr_name, repr(spec.quantity))) + msg = "%s '%s' has %d values for attribute '%s' but spec allows %s." 
% ( + container.__class__.__name__, + container.name, + n, + attr_name, + repr(spec.quantity), + ) warnings.warn(msg, IncorrectQuantityBuildWarning) - self.logger.debug('IncorrectQuantityBuildWarning: ' + msg) + self.logger.debug("IncorrectQuantityBuildWarning: " + msg) - @docval({"name": "spec", "type": Spec, "doc": "the spec to get the constructor argument for"}, - returns="the name of the constructor argument", rtype=str) + @docval( + {"name": "spec", "type": Spec, "doc": "the spec to get the constructor argument for"}, + returns="the name of the constructor argument", + rtype=str, + ) def get_const_arg(self, **kwargs): - ''' Get the constructor argument for the given Spec ''' - spec = getargs('spec', kwargs) + """Get the constructor argument for the given Spec""" + spec = getargs("spec", kwargs) return self.__spec2carg.get(spec, None) - @docval({"name": "container", "type": AbstractContainer, "doc": "the container to convert to a Builder"}, - {"name": "manager", "type": BuildManager, "doc": "the BuildManager to use for managing this build"}, - {"name": "parent", "type": GroupBuilder, "doc": "the parent of the resulting Builder", 'default': None}, - {"name": "source", "type": str, - "doc": "the source of container being built i.e. file path", 'default': None}, - {"name": "builder", "type": BaseBuilder, "doc": "the Builder to build on", 'default': None}, - {"name": "spec_ext", "type": BaseStorageSpec, "doc": "a spec extension", 'default': None}, - {"name": "export", "type": bool, "doc": "whether this build is for exporting", - 'default': False}, - returns="the Builder representing the given AbstractContainer", rtype=Builder) + @docval( + {"name": "container", "type": AbstractContainer, "doc": "the container to convert to a Builder"}, + {"name": "manager", "type": BuildManager, "doc": "the BuildManager to use for managing this build"}, + {"name": "parent", "type": GroupBuilder, "doc": "the parent of the resulting Builder", "default": None}, + {"name": "source", "type": str, "doc": "the source of container being built i.e. file path", "default": None}, + {"name": "builder", "type": BaseBuilder, "doc": "the Builder to build on", "default": None}, + {"name": "spec_ext", "type": BaseStorageSpec, "doc": "a spec extension", "default": None}, + {"name": "export", "type": bool, "doc": "whether this build is for exporting", "default": False}, + returns="the Builder representing the given AbstractContainer", + rtype=Builder, + ) def build(self, **kwargs): - '''Convert an AbstractContainer to a Builder representation. + """Convert an AbstractContainer to a Builder representation. References are not added but are queued to be added in the BuildManager. 
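End to end, ``build`` is normally reached through a ``BuildManager`` rather than called directly. A short sketch using the ``hdmf.common`` helpers (the table contents are illustrative):

.. code:: python

    from hdmf.common import DynamicTable, VectorData, get_manager

    manager = get_manager()
    table = DynamicTable(
        name="trials",
        description="example table",
        columns=[VectorData(name="start", description="start times", data=[0.0, 1.0])],
    )
    builder = manager.build(table)  # dispatches to ObjectMapper.build under the hood
    print(builder.name, sorted(builder.attributes))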
- ''' - container, manager, parent, source = getargs('container', 'manager', 'parent', 'source', kwargs) - builder, spec_ext, export = getargs('builder', 'spec_ext', 'export', kwargs) + """ + container, manager, parent, source = getargs("container", "manager", "parent", "source", kwargs) + builder, spec_ext, export = getargs("builder", "spec_ext", "export", kwargs) name = manager.get_builder_name(container) if isinstance(self.__spec, GroupSpec): - self.logger.debug("Building %s '%s' as a group (source: %s)" - % (container.__class__.__name__, container.name, repr(source))) + self.logger.debug( + "Building %s '%s' as a group (source: %s)" + % (container.__class__.__name__, container.name, repr(source)) + ) if builder is None: builder = GroupBuilder(name, parent=parent, source=source) self.__add_datasets(builder, self.__spec.datasets, container, manager, source, export) @@ -720,43 +793,57 @@ def build(self, **kwargs): raise ValueError(msg) spec_dtype, spec_shape, spec = self.__check_dset_spec(self.spec, spec_ext) if isinstance(spec_dtype, RefSpec): - self.logger.debug("Building %s '%s' as a dataset of references (source: %s)" - % (container.__class__.__name__, container.name, repr(source))) + self.logger.debug( + "Building %s '%s' as a dataset of references (source: %s)" + % (container.__class__.__name__, container.name, repr(source)) + ) # create dataset builder with data=None as a placeholder. fill in with refs later - builder = DatasetBuilder(name, data=None, parent=parent, source=source, dtype=spec_dtype.reftype) + builder = DatasetBuilder( + name, + data=None, + parent=parent, + source=source, + dtype=spec_dtype.reftype, + ) manager.queue_ref(self.__set_dataset_to_refs(builder, spec_dtype, spec_shape, container, manager)) elif isinstance(spec_dtype, list): # a compound dataset - self.logger.debug("Building %s '%s' as a dataset of compound dtypes (source: %s)" - % (container.__class__.__name__, container.name, repr(source))) + self.logger.debug( + "Building %s '%s' as a dataset of compound dtypes (source: %s)" + % (container.__class__.__name__, container.name, repr(source)) + ) # create dataset builder with data=None, dtype=None as a placeholder. fill in with refs later builder = DatasetBuilder(name, data=None, parent=parent, source=source, dtype=spec_dtype) - manager.queue_ref(self.__set_compound_dataset_to_refs(builder, spec, spec_dtype, container, - manager)) + manager.queue_ref( + self.__set_compound_dataset_to_refs(builder, spec, spec_dtype, container, manager) + ) else: # a regular dtype if spec_dtype is None and self.__is_reftype(container.data): - self.logger.debug("Building %s '%s' containing references as a dataset of unspecified dtype " - "(source: %s)" - % (container.__class__.__name__, container.name, repr(source))) + self.logger.debug( + "Building %s '%s' containing references as a dataset of unspecified dtype (source: %s)" + % (container.__class__.__name__, container.name, repr(source)) + ) # an unspecified dtype and we were given references # create dataset builder with data=None as a placeholder. 
fill in with refs later - builder = DatasetBuilder(name, data=None, parent=parent, source=source, dtype='object') + builder = DatasetBuilder(name, data=None, parent=parent, source=source, dtype="object") manager.queue_ref(self.__set_untyped_dataset_to_refs(builder, container, manager)) else: # a dataset that has no references, pass the conversion off to the convert_dtype method - self.logger.debug("Building %s '%s' as a dataset (source: %s)" - % (container.__class__.__name__, container.name, repr(source))) + self.logger.debug( + "Building %s '%s' as a dataset (source: %s)" + % (container.__class__.__name__, container.name, repr(source)) + ) try: # use spec_dtype from self.spec when spec_ext does not specify dtype bldr_data, dtype = self.convert_dtype(spec, container.data, spec_dtype=spec_dtype) except Exception as ex: - msg = 'could not resolve dtype for %s \'%s\'' % (type(container).__name__, container.name) + msg = "could not resolve dtype for %s '%s'" % (type(container).__name__, container.name) raise Exception(msg) from ex builder = DatasetBuilder(name, bldr_data, parent=parent, source=source, dtype=dtype) # Add attributes from the specification extension to the list of attributes - all_attrs = self.__spec.attributes + getattr(spec_ext, 'attributes', tuple()) + all_attrs = self.__spec.attributes + getattr(spec_ext, "attributes", tuple()) # If the spec_ext refines an existing attribute it will now appear twice in the list. The # refinement should only be relevant for validation (not for write). To avoid problems with the # write we here remove duplicates and keep the original spec of the two to make write work. @@ -782,18 +869,19 @@ def __check_dset_spec(self, orig, ext): return dtype, shape, spec def __is_reftype(self, data): - if (isinstance(data, AbstractDataChunkIterator) or - (isinstance(data, DataIO) and isinstance(data.data, AbstractDataChunkIterator))): + if isinstance(data, AbstractDataChunkIterator) or ( + isinstance(data, DataIO) and isinstance(data.data, AbstractDataChunkIterator) + ): return False tmp = data - while hasattr(tmp, '__len__') and not isinstance(tmp, (AbstractContainer, str, bytes)): + while hasattr(tmp, "__len__") and not isinstance(tmp, (AbstractContainer, str, bytes)): tmptmp = None for t in tmp: # In case of a numeric array stop the iteration at the first element to avoid long-running loop if isinstance(t, (int, float, complex, bool)): break - if hasattr(t, '__len__') and len(t) > 0 and not isinstance(t, (AbstractContainer, str, bytes)): + if hasattr(t, "__len__") and len(t) > 0 and not isinstance(t, (AbstractContainer, str, bytes)): tmptmp = tmp[0] break if tmptmp is not None: @@ -809,8 +897,10 @@ def __is_reftype(self, data): return False def __set_dataset_to_refs(self, builder, dtype, shape, container, build_manager): - self.logger.debug("Queueing set dataset of references %s '%s' to reference builder(s)" - % (builder.__class__.__name__, builder.name)) + self.logger.debug( + "Queueing set dataset of references %s '%s' to reference builder(s)" + % (builder.__class__.__name__, builder.name) + ) def _filler(): builder.data = self.__get_ref_builder(builder, dtype, shape, container, build_manager) @@ -818,12 +908,16 @@ def _filler(): return _filler def __set_compound_dataset_to_refs(self, builder, spec, spec_dtype, container, build_manager): - self.logger.debug("Queueing convert compound dataset %s '%s' and set any references to reference builders" - % (builder.__class__.__name__, builder.name)) + self.logger.debug( + "Queueing convert compound dataset %s 
'%s' and set any references to reference builders" + % (builder.__class__.__name__, builder.name) + ) def _filler(): - self.logger.debug("Converting compound dataset %s '%s' and setting any references to reference builders" - % (builder.__class__.__name__, builder.name)) + self.logger.debug( + "Converting compound dataset %s '%s' and setting any references to reference builders" + % (builder.__class__.__name__, builder.name) + ) # convert the reference part(s) of a compound dataset to ReferenceBuilders, row by row refs = [(i, subt) for i, subt in enumerate(spec_dtype) if isinstance(subt.dtype, RefSpec)] bldr_data = list() @@ -837,12 +931,15 @@ def _filler(): return _filler def __set_untyped_dataset_to_refs(self, builder, container, build_manager): - self.logger.debug("Queueing set untyped dataset %s '%s' to reference builders" - % (builder.__class__.__name__, builder.name)) + self.logger.debug( + "Queueing set untyped dataset %s '%s' to reference builders" % (builder.__class__.__name__, builder.name) + ) def _filler(): - self.logger.debug("Setting untyped dataset %s '%s' to list of reference builders" - % (builder.__class__.__name__, builder.name)) + self.logger.debug( + "Setting untyped dataset %s '%s' to list of reference builders" + % (builder.__class__.__name__, builder.name) + ) bldr_data = list() for d in container.data: if d is None: @@ -861,30 +958,36 @@ def __get_ref_builder(self, builder, dtype, shape, container, build_manager): if not isinstance(container, DataRegion): msg = "'container' must be of type DataRegion if spec represents region reference" raise ValueError(msg) - self.logger.debug("Setting %s '%s' data to region reference builder" - % (builder.__class__.__name__, builder.name)) + self.logger.debug( + "Setting %s '%s' data to region reference builder" % (builder.__class__.__name__, builder.name) + ) target_builder = self.__get_target_builder(container.data, build_manager, builder) bldr_data = RegionBuilder(container.region, target_builder) else: - self.logger.debug("Setting %s '%s' data to list of region reference builders" - % (builder.__class__.__name__, builder.name)) + self.logger.debug( + "Setting %s '%s' data to list of region reference builders" + % (builder.__class__.__name__, builder.name) + ) bldr_data = list() for d in container.data: target_builder = self.__get_target_builder(d.target, build_manager, builder) bldr_data.append(RegionBuilder(d.slice, target_builder)) else: - self.logger.debug("Setting object reference dataset on %s '%s' data" - % (builder.__class__.__name__, builder.name)) + self.logger.debug( + "Setting object reference dataset on %s '%s' data" % (builder.__class__.__name__, builder.name) + ) if isinstance(container, Data): - self.logger.debug("Setting %s '%s' data to list of reference builders" - % (builder.__class__.__name__, builder.name)) + self.logger.debug( + "Setting %s '%s' data to list of reference builders" % (builder.__class__.__name__, builder.name) + ) bldr_data = list() for d in container.data: target_builder = self.__get_target_builder(d, build_manager, builder) bldr_data.append(ReferenceBuilder(target_builder)) else: - self.logger.debug("Setting %s '%s' data to reference builder" - % (builder.__class__.__name__, builder.name)) + self.logger.debug( + "Setting %s '%s' data to reference builder" % (builder.__class__.__name__, builder.name) + ) target_builder = self.__get_target_builder(container, build_manager, builder) bldr_data = ReferenceBuilder(target_builder) return bldr_data @@ -897,12 +1000,19 @@ def 
__get_target_builder(self, container, build_manager, builder): def __add_attributes(self, builder, attributes, container, build_manager, source, export): if attributes: - self.logger.debug("Adding attributes from %s '%s' to %s '%s'" - % (container.__class__.__name__, container.name, - builder.__class__.__name__, builder.name)) + self.logger.debug( + "Adding attributes from %s '%s' to %s '%s'" + % ( + container.__class__.__name__, + container.name, + builder.__class__.__name__, + builder.name, + ) + ) for spec in attributes: - self.logger.debug(" Adding attribute for spec name: %s (dtype: %s)" - % (repr(spec.name), spec.dtype.__class__.__name__)) + self.logger.debug( + " Adding attribute for spec name: %s (dtype: %s)" % (repr(spec.name), spec.dtype.__class__.__name__) + ) if spec.value is not None: attr_value = spec.value else: @@ -919,8 +1029,10 @@ def __add_attributes(self, builder, attributes, container, build_manager, source if isinstance(spec.dtype, RefSpec): if not self.__is_reftype(attr_value): - msg = ("invalid type for reference '%s' (%s) - must be AbstractContainer" - % (spec.name, type(attr_value))) + msg = "invalid type for reference '%s' (%s) - must be AbstractContainer" % ( + spec.name, + type(attr_value), + ) raise ValueError(msg) build_manager.queue_ref(self.__set_attr_to_ref(builder, attr_value, build_manager, spec)) @@ -929,7 +1041,7 @@ def __add_attributes(self, builder, attributes, container, build_manager, source try: attr_value, attr_dtype = self.convert_dtype(spec, attr_value) except Exception as ex: - msg = 'could not convert %s for %s %s' % (spec.name, type(container).__name__, container.name) + msg = "could not convert %s for %s %s" % (spec.name, type(container).__name__, container.name) raise BuildError(builder, msg) from ex # do not write empty or null valued objects @@ -941,14 +1053,16 @@ def __add_attributes(self, builder, attributes, container, build_manager, source builder.set_attribute(spec.name, attr_value) def __set_attr_to_ref(self, builder, attr_value, build_manager, spec): - self.logger.debug("Queueing set reference attribute on %s '%s' attribute '%s' to %s" - % (builder.__class__.__name__, builder.name, spec.name, - attr_value.__class__.__name__)) + self.logger.debug( + "Queueing set reference attribute on %s '%s' attribute '%s' to %s" + % (builder.__class__.__name__, builder.name, spec.name, attr_value.__class__.__name__) + ) def _filler(): - self.logger.debug("Setting reference attribute on %s '%s' attribute '%s' to %s" - % (builder.__class__.__name__, builder.name, spec.name, - attr_value.__class__.__name__)) + self.logger.debug( + "Setting reference attribute on %s '%s' attribute '%s' to %s" + % (builder.__class__.__name__, builder.name, spec.name, attr_value.__class__.__name__) + ) target_builder = self.__get_target_builder(attr_value, build_manager, builder) ref_attr_value = ReferenceBuilder(target_builder) builder.set_attribute(spec.name, ref_attr_value) @@ -957,12 +1071,14 @@ def _filler(): def __add_links(self, builder, links, container, build_manager, source, export): if links: - self.logger.debug("Adding links from %s '%s' to %s '%s'" - % (container.__class__.__name__, container.name, - builder.__class__.__name__, builder.name)) + self.logger.debug( + "Adding links from %s '%s' to %s '%s'" + % (container.__class__.__name__, container.name, builder.__class__.__name__, builder.name) + ) for spec in links: - self.logger.debug(" Adding link for spec name: %s, target_type: %s" - % (repr(spec.name), repr(spec.target_type))) + self.logger.debug( + 
" Adding link for spec name: %s, target_type: %s" % (repr(spec.name), repr(spec.target_type)) + ) attr_value = self.get_attr_value(spec, container, build_manager) self.__check_quantity(attr_value, spec, container) if attr_value is None: @@ -972,12 +1088,14 @@ def __add_links(self, builder, links, container, build_manager, source, export): def __add_datasets(self, builder, datasets, container, build_manager, source, export): if datasets: - self.logger.debug("Adding datasets from %s '%s' to %s '%s'" - % (container.__class__.__name__, container.name, - builder.__class__.__name__, builder.name)) + self.logger.debug( + "Adding datasets from %s '%s' to %s '%s'" + % (container.__class__.__name__, container.name, builder.__class__.__name__, builder.name) + ) for spec in datasets: - self.logger.debug(" Adding dataset for spec name: %s (dtype: %s)" - % (repr(spec.name), spec.dtype.__class__.__name__)) + self.logger.debug( + " Adding dataset for spec name: %s (dtype: %s)" % (repr(spec.name), spec.dtype.__class__.__name__) + ) attr_value = self.get_attr_value(spec, container, build_manager) self.__check_quantity(attr_value, spec, container) if attr_value is None: @@ -985,43 +1103,69 @@ def __add_datasets(self, builder, datasets, container, build_manager, source, ex continue attr_value = self.__check_ref_resolver(attr_value) if isinstance(attr_value, LinkBuilder): - self.logger.debug(" Adding %s '%s' for spec name: %s, %s: %s, %s: %s" - % (attr_value.name, attr_value.__class__.__name__, - repr(spec.name), - spec.def_key(), repr(spec.data_type_def), - spec.inc_key(), repr(spec.data_type_inc))) + self.logger.debug( + " Adding %s '%s' for spec name: %s, %s: %s, %s: %s" + % ( + attr_value.name, + attr_value.__class__.__name__, + repr(spec.name), + spec.def_key(), + repr(spec.data_type_def), + spec.inc_key(), + repr(spec.data_type_inc), + ) + ) builder.set_link(attr_value) # add the existing builder elif spec.data_type_def is None and spec.data_type_inc is None: # untyped, named dataset if spec.name in builder.datasets: sub_builder = builder.datasets[spec.name] - self.logger.debug(" Retrieving existing DatasetBuilder '%s' for spec name %s and adding " - "attributes" % (sub_builder.name, repr(spec.name))) + self.logger.debug( + " Retrieving existing DatasetBuilder '%s' for spec name %s and adding attributes" + % (sub_builder.name, repr(spec.name)) + ) else: - self.logger.debug(" Converting untyped dataset for spec name %s to spec dtype %s" - % (repr(spec.name), repr(spec.dtype))) + self.logger.debug( + " Converting untyped dataset for spec name %s to spec dtype %s" + % (repr(spec.name), repr(spec.dtype)) + ) try: data, dtype = self.convert_dtype(spec, attr_value) except Exception as ex: - msg = 'could not convert \'%s\' for %s \'%s\'' + msg = "could not convert '%s' for %s '%s'" msg = msg % (spec.name, type(container).__name__, container.name) raise BuildError(builder, msg) from ex - self.logger.debug(" Adding untyped dataset for spec name %s and adding attributes" - % repr(spec.name)) + self.logger.debug( + " Adding untyped dataset for spec name %s and adding attributes" % repr(spec.name) + ) sub_builder = DatasetBuilder(spec.name, data, parent=builder, source=source, dtype=dtype) builder.set_dataset(sub_builder) - self.__add_attributes(sub_builder, spec.attributes, container, build_manager, source, export) + self.__add_attributes( + sub_builder, + spec.attributes, + container, + build_manager, + source, + export, + ) else: - self.logger.debug(" Adding typed dataset for spec name: %s, %s: %s, %s: %s" - % 
(repr(spec.name), - spec.def_key(), repr(spec.data_type_def), - spec.inc_key(), repr(spec.data_type_inc))) + self.logger.debug( + " Adding typed dataset for spec name: %s, %s: %s, %s: %s" + % ( + repr(spec.name), + spec.def_key(), + repr(spec.data_type_def), + spec.inc_key(), + repr(spec.data_type_inc), + ) + ) self.__add_containers(builder, spec, attr_value, build_manager, source, container, export) def __add_groups(self, builder, groups, container, build_manager, source, export): if groups: - self.logger.debug("Adding groups from %s '%s' to %s '%s'" - % (container.__class__.__name__, container.name, - builder.__class__.__name__, builder.name)) + self.logger.debug( + "Adding groups from %s '%s' to %s '%s'" + % (container.__class__.__name__, container.name, builder.__class__.__name__, builder.name) + ) for spec in groups: if spec.data_type_def is None and spec.data_type_inc is None: self.logger.debug(" Adding untyped group for spec name: %s" % repr(spec.name)) @@ -1029,7 +1173,14 @@ def __add_groups(self, builder, groups, container, build_manager, source, export sub_builder = builder.groups.get(spec.name) if sub_builder is None: sub_builder = GroupBuilder(spec.name, source=source) - self.__add_attributes(sub_builder, spec.attributes, container, build_manager, source, export) + self.__add_attributes( + sub_builder, + spec.attributes, + container, + build_manager, + source, + export, + ) self.__add_datasets(sub_builder, spec.datasets, container, build_manager, source, export) self.__add_links(sub_builder, spec.links, container, build_manager, source, export) self.__add_groups(sub_builder, spec.groups, container, build_manager, source, export) @@ -1038,24 +1189,47 @@ def __add_groups(self, builder, groups, container, build_manager, source, export if sub_builder.name not in builder.groups: builder.set_group(sub_builder) else: - self.logger.debug(" Adding group for spec name: %s, %s: %s, %s: %s" - % (repr(spec.name), - spec.def_key(), repr(spec.data_type_def), - spec.inc_key(), repr(spec.data_type_inc))) + self.logger.debug( + " Adding group for spec name: %s, %s: %s, %s: %s" + % ( + repr(spec.name), + spec.def_key(), + repr(spec.data_type_def), + spec.inc_key(), + repr(spec.data_type_inc), + ) + ) attr_value = self.get_attr_value(spec, container, build_manager) self.__check_quantity(attr_value, spec, container) if attr_value is not None: - self.__add_containers(builder, spec, attr_value, build_manager, source, container, export) + self.__add_containers( + builder, + spec, + attr_value, + build_manager, + source, + container, + export, + ) def __add_containers(self, builder, spec, value, build_manager, source, parent_container, export): if isinstance(value, AbstractContainer): - self.logger.debug(" Adding container %s '%s' with parent %s '%s' to %s '%s'" - % (value.__class__.__name__, value.name, - parent_container.__class__.__name__, parent_container.name, - builder.__class__.__name__, builder.name)) + self.logger.debug( + " Adding container %s '%s' with parent %s '%s' to %s '%s'" + % ( + value.__class__.__name__, + value.name, + parent_container.__class__.__name__, + parent_container.name, + builder.__class__.__name__, + builder.name, + ) + ) if value.parent is None: - if (value.container_source == parent_container.container_source or - build_manager.get_builder(value) is None): + if ( + value.container_source == parent_container.container_source + or build_manager.get_builder(value) is None + ): # value was removed (or parent not set) and there is a link to it in same file # or value was 
read from an external link raise OrphanContainerBuildError(builder, value) @@ -1070,43 +1244,57 @@ def __add_containers(self, builder, spec, value, build_manager, source, parent_c new_builder = build_manager.build(value, source=source, export=export) # use spec to determine what kind of HDF5 object this AbstractContainer corresponds to if isinstance(spec, LinkSpec) or value.parent is not parent_container: - self.logger.debug(" Adding link to %s '%s' in %s '%s'" - % (new_builder.__class__.__name__, new_builder.name, - builder.__class__.__name__, builder.name)) + self.logger.debug( + " Adding link to %s '%s' in %s '%s'" + % (new_builder.__class__.__name__, new_builder.name, builder.__class__.__name__, builder.name) + ) builder.set_link(LinkBuilder(new_builder, name=spec.name, parent=builder)) elif isinstance(spec, DatasetSpec): - self.logger.debug(" Adding dataset %s '%s' to %s '%s'" - % (new_builder.__class__.__name__, new_builder.name, - builder.__class__.__name__, builder.name)) + self.logger.debug( + " Adding dataset %s '%s' to %s '%s'" + % (new_builder.__class__.__name__, new_builder.name, builder.__class__.__name__, builder.name) + ) builder.set_dataset(new_builder) else: - self.logger.debug(" Adding subgroup %s '%s' to %s '%s'" - % (new_builder.__class__.__name__, new_builder.name, - builder.__class__.__name__, builder.name)) + self.logger.debug( + " Adding subgroup %s '%s' to %s '%s'" + % (new_builder.__class__.__name__, new_builder.name, builder.__class__.__name__, builder.name) + ) builder.set_group(new_builder) elif value.container_source: # make a link to an existing container - if (value.container_source != parent_container.container_source - or value.parent is not parent_container): - self.logger.debug(" Building %s '%s' (container source: %s) and adding a link to it" - % (value.__class__.__name__, value.name, value.container_source)) + if value.container_source != parent_container.container_source or value.parent is not parent_container: + self.logger.debug( + " Building %s '%s' (container source: %s) and adding a link to it" + % (value.__class__.__name__, value.name, value.container_source) + ) if isinstance(spec, BaseStorageSpec): new_builder = build_manager.build(value, source=source, spec_ext=spec, export=export) else: new_builder = build_manager.build(value, source=source, export=export) builder.set_link(LinkBuilder(new_builder, name=spec.name, parent=builder)) else: - self.logger.debug(" Skipping build for %s '%s' because both it and its parents were read " - "from the same source." - % (value.__class__.__name__, value.name)) + self.logger.debug( + " Skipping build for %s '%s' because both it and its parents were read from the same source." + % (value.__class__.__name__, value.name) + ) else: - raise ValueError("Found unmodified AbstractContainer with no source - '%s' with parent '%s'" % - (value.name, parent_container.name)) + raise ValueError( + "Found unmodified AbstractContainer with no source - '%s' with parent '%s'" + % (value.name, parent_container.name) + ) elif isinstance(value, list): for container in value: - self.__add_containers(builder, spec, container, build_manager, source, parent_container, export) + self.__add_containers( + builder, + spec, + container, + build_manager, + source, + parent_container, + export, + ) else: # pragma: no cover - msg = ("Received %s, expected AbstractContainer or a list of AbstractContainers." - % value.__class__.__name__) + msg = "Received %s, expected AbstractContainer or a list of AbstractContainers." 
% value.__class__.__name__ raise ValueError(msg) def __get_subspec_values(self, builder, spec, manager): @@ -1157,10 +1345,9 @@ def __get_subspec_values(self, builder, spec, manager): elif isinstance(spec, DatasetSpec): if not isinstance(builder, DatasetBuilder): # pragma: no cover raise ValueError("__get_subspec_values - must pass DatasetBuilder with DatasetSpec") - if (spec.shape is None and getattr(builder.data, 'shape', None) == (1,) and - type(builder.data[0]) != np.void): + if spec.shape is None and getattr(builder.data, "shape", None) == (1,) and type(builder.data[0]) != np.void: # if a scalar dataset is expected and a 1-element non-compound dataset is given, then read the dataset - builder['data'] = builder.data[0] # use dictionary reference instead of .data to bypass error + builder["data"] = builder.data[0] # use dictionary reference instead of .data to bypass error ret[spec] = self.__check_ref_resolver(builder.data) return ret @@ -1213,14 +1400,27 @@ def __flatten(self, sub_builder, subspec, manager): tmp = tmp[0] return tmp - @docval({'name': 'builder', 'type': (DatasetBuilder, GroupBuilder), - 'doc': 'the builder to construct the AbstractContainer from'}, - {'name': 'manager', 'type': BuildManager, 'doc': 'the BuildManager for this build'}, - {'name': 'parent', 'type': (Proxy, AbstractContainer), - 'doc': 'the parent AbstractContainer/Proxy for the AbstractContainer being built', 'default': None}) + @docval( + { + "name": "builder", + "type": (DatasetBuilder, GroupBuilder), + "doc": "the builder to construct the AbstractContainer from", + }, + { + "name": "manager", + "type": BuildManager, + "doc": "the BuildManager for this build", + }, + { + "name": "parent", + "type": (Proxy, AbstractContainer), + "doc": "the parent AbstractContainer/Proxy for the AbstractContainer being built", + "default": None, + }, + ) def construct(self, **kwargs): - ''' Construct an AbstractContainer from the given Builder ''' - builder, manager, parent = getargs('builder', 'manager', 'parent', kwargs) + """Construct an AbstractContainer from the given Builder""" + builder, manager, parent = getargs("builder", "manager", "parent", kwargs) cls = manager.get_cls(builder) # gather all subspecs subspecs = self.__get_subspec_values(builder, self.spec, manager) @@ -1230,8 +1430,8 @@ def construct(self, **kwargs): # there is no sub-specification that maps to that argument under the default logic if issubclass(cls, Data): if not isinstance(builder, DatasetBuilder): # pragma: no cover - raise ValueError('Can only construct a Data object from a DatasetBuilder - got %s' % type(builder)) - const_args['data'] = self.__check_ref_resolver(builder.data) + raise ValueError("Can only construct a Data object from a DatasetBuilder - got %s" % type(builder)) + const_args["data"] = self.__check_ref_resolver(builder.data) for subspec, value in subspecs.items(): const_arg = self.get_const_arg(subspec) if const_arg is not None: @@ -1243,7 +1443,7 @@ def construct(self, **kwargs): # build kwargs for the constructor kwargs = dict() for const_arg in get_docval(cls.__init__): - argname = const_arg['name'] + argname = const_arg["name"] override = self.__get_override_carg(argname, builder, manager) if override is not None: val = override @@ -1253,28 +1453,39 @@ def construct(self, **kwargs): continue kwargs[argname] = val try: - obj = self.__new_container__(cls, builder.source, parent, builder.attributes.get(self.__spec.id_key()), - **kwargs) + obj = self.__new_container__( + cls, + builder.source, + parent, + 
builder.attributes.get(self.__spec.id_key()), + **kwargs, + ) except Exception as ex: - msg = 'Could not construct %s object due to: %s' % (cls.__name__, ex) + msg = "Could not construct %s object due to: %s" % (cls.__name__, ex) raise ConstructError(builder, msg) from ex return obj def __new_container__(self, cls, container_source, parent, object_id, **kwargs): """A wrapper function for ensuring a container gets everything set appropriately""" - obj = cls.__new__(cls, container_source=container_source, parent=parent, object_id=object_id, - in_construct_mode=True) + obj = cls.__new__( + cls, + container_source=container_source, + parent=parent, + object_id=object_id, + in_construct_mode=True, + ) # obj has been created and is in construction mode, indicating that the object is being constructed by # the automatic construct process during read, rather than by the user obj.__init__(**kwargs) obj._in_construct_mode = False # reset to False to indicate that the construction of the object is complete return obj - @docval({'name': 'container', 'type': AbstractContainer, - 'doc': 'the AbstractContainer to get the Builder name for'}) + @docval( + {"name": "container", "type": AbstractContainer, "doc": "the AbstractContainer to get the Builder name for"} + ) def get_builder_name(self, **kwargs): - '''Get the name of a Builder that represents a AbstractContainer''' - container = getargs('container', kwargs) + """Get the name of a Builder that represents a AbstractContainer""" + container = getargs("container", kwargs) if self.__spec.name is not None: ret = self.__spec.name else: diff --git a/src/hdmf/build/warnings.py b/src/hdmf/build/warnings.py index 3d5f02126..e5960b8ad 100644 --- a/src/hdmf/build/warnings.py +++ b/src/hdmf/build/warnings.py @@ -5,6 +5,7 @@ class BuildWarning(UserWarning): """ Base class for warnings that are raised during the building of a container. """ + pass @@ -12,6 +13,7 @@ class IncorrectQuantityBuildWarning(BuildWarning): """ Raised when a container field contains a number of groups/datasets/links that is not allowed by the spec. """ + pass @@ -19,6 +21,7 @@ class MissingRequiredBuildWarning(BuildWarning): """ Raised when a required field is missing. """ + pass @@ -26,6 +29,7 @@ class MissingRequiredWarning(MissingRequiredBuildWarning): """ Raised when a required field is missing. """ + pass @@ -33,6 +37,7 @@ class OrphanContainerWarning(BuildWarning): """ Raised when a container is built without a parent. """ + pass @@ -40,4 +45,5 @@ class DtypeConversionWarning(UserWarning): """ Raised when a value is converted to a different data type in order to match the specification. 
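Since the build warnings in this module are plain ``UserWarning`` subclasses, callers can tighten or relax them with the standard ``warnings`` filters. A small sketch; the particular filter choices are illustrative.

.. code:: python

    import warnings

    from hdmf.build.warnings import DtypeConversionWarning, MissingRequiredBuildWarning

    with warnings.catch_warnings():
        # treat silent dtype conversions as errors, keep missing-field warnings visible
        warnings.simplefilter("error", category=DtypeConversionWarning)
        warnings.simplefilter("always", category=MissingRequiredBuildWarning)
        # ... build or write containers here ...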
""" + pass diff --git a/src/hdmf/common/__init__.py b/src/hdmf/common/__init__.py index 688e6105a..3e3675247 100644 --- a/src/hdmf/common/__init__.py +++ b/src/hdmf/common/__init__.py @@ -1,44 +1,61 @@ -'''This package will contain functions, classes, and objects +"""This package will contain functions, classes, and objects for reading and writing data in according to the HDMF-common specification -''' +""" import os.path from copy import deepcopy -CORE_NAMESPACE = 'hdmf-common' -EXP_NAMESPACE = 'hdmf-experimental' +CORE_NAMESPACE = "hdmf-common" +EXP_NAMESPACE = "hdmf-experimental" -from ..spec import NamespaceCatalog # noqa: E402 -from ..utils import docval, getargs, get_docval # noqa: E402 -from ..backends.io import HDMFIO # noqa: E402 from ..backends.hdf5 import HDF5IO # noqa: E402 -from ..validate import ValidatorMap # noqa: E402 +from ..backends.io import HDMFIO # noqa: E402 from ..build import BuildManager, TypeMap # noqa: E402 from ..container import _set_exp # noqa: E402 - +from ..spec import NamespaceCatalog # noqa: E402 +from ..utils import docval, get_docval, getargs # noqa: E402 +from ..validate import ValidatorMap # noqa: E402 # a global type map global __TYPE_MAP # a function to register a container classes with the global map -@docval({'name': 'data_type', 'type': str, 'doc': 'the data_type to get the spec for'}, - {'name': 'namespace', 'type': str, 'doc': 'the name of the namespace', 'default': CORE_NAMESPACE}, - {"name": "container_cls", "type": type, - "doc": "the class to map to the specified data_type", 'default': None}, - is_method=False) +@docval( + { + "name": "data_type", + "type": str, + "doc": "the data_type to get the spec for", + }, + { + "name": "namespace", + "type": str, + "doc": "the name of the namespace", + "default": CORE_NAMESPACE, + }, + { + "name": "container_cls", + "type": type, + "doc": "the class to map to the specified data_type", + "default": None, + }, + is_method=False, +) def register_class(**kwargs): """Register an Container class to use for reading and writing a data_type from a specification If container_cls is not specified, returns a decorator for registering an Container subclass as the class for data_type in namespace. """ - data_type, namespace, container_cls = getargs('data_type', 'namespace', 'container_cls', kwargs) + data_type, namespace, container_cls = getargs("data_type", "namespace", "container_cls", kwargs) if namespace == EXP_NAMESPACE: + def _dec(cls): _set_exp(cls) __TYPE_MAP.register_container_type(namespace, data_type, cls) return cls + else: + def _dec(cls): __TYPE_MAP.register_container_type(namespace, data_type, cls) return cls @@ -50,20 +67,31 @@ def _dec(cls): # a function to register an object mapper for a container class -@docval({"name": "container_cls", "type": type, - "doc": "the Container class for which the given ObjectMapper class gets used for"}, - {"name": "mapper_cls", "type": type, "doc": "the ObjectMapper class to use to map", 'default': None}, - is_method=False) +@docval( + { + "name": "container_cls", + "type": type, + "doc": "the Container class for which the given ObjectMapper class gets used for", + }, + { + "name": "mapper_cls", + "type": type, + "doc": "the ObjectMapper class to use to map", + "default": None, + }, + is_method=False, +) def register_map(**kwargs): """Register an ObjectMapper to use for a Container class type If mapper_cls is not specified, returns a decorator for registering an ObjectMapper class as the mapper for container_cls. 
If mapper_cls specified, register the class as the mapper for container_cls """ - container_cls, mapper_cls = getargs('container_cls', 'mapper_cls', kwargs) + container_cls, mapper_cls = getargs("container_cls", "mapper_cls", kwargs) def _dec(cls): __TYPE_MAP.register_map(container_cls, cls) return cls + if mapper_cls is None: return _dec else: @@ -78,11 +106,11 @@ def __get_resources(): from importlib_resources import files __location_of_this_file = files(__name__) - __core_ns_file_name = 'namespace.yaml' - __schema_dir = 'hdmf-common-schema/common' + __core_ns_file_name = "namespace.yaml" + __schema_dir = "hdmf-common-schema/common" ret = dict() - ret['namespace_path'] = str(__location_of_this_file / __schema_dir / __core_ns_file_name) + ret["namespace_path"] = str(__location_of_this_file / __schema_dir / __core_ns_file_name) return ret @@ -91,15 +119,21 @@ def _get_resources(): return __get_resources() -@docval({'name': 'namespace_path', 'type': str, - 'doc': 'the path to the YAML with the namespace definition'}, - returns="the namespaces loaded from the given file", rtype=tuple, - is_method=False) +@docval( + { + "name": "namespace_path", + "type": str, + "doc": "the path to the YAML with the namespace definition", + }, + returns="the namespaces loaded from the given file", + rtype=tuple, + is_method=False, +) def load_namespaces(**kwargs): - ''' + """ Load namespaces from file - ''' - namespace_path = getargs('namespace_path', kwargs) + """ + namespace_path = getargs("namespace_path", kwargs) return __TYPE_MAP.load_namespaces(namespace_path) @@ -108,28 +142,42 @@ def available_namespaces(): # a function to get the container class for a give type -@docval({'name': 'data_type', 'type': str, - 'doc': 'the data_type to get the Container class for'}, - {'name': 'namespace', 'type': str, 'doc': 'the namespace the data_type is defined in'}, - is_method=False) +@docval( + { + "name": "data_type", + "type": str, + "doc": "the data_type to get the Container class for", + }, + { + "name": "namespace", + "type": str, + "doc": "the namespace the data_type is defined in", + }, + is_method=False, +) def get_class(**kwargs): - """Get the class object of the Container subclass corresponding to a given neurdata_type. - """ - data_type, namespace = getargs('data_type', 'namespace', kwargs) + """Get the class object of the Container subclass corresponding to a given neurdata_type.""" + data_type, namespace = getargs("data_type", "namespace", kwargs) return __TYPE_MAP.get_dt_container_cls(data_type, namespace) -@docval({'name': 'extensions', 'type': (str, TypeMap, list), - 'doc': 'a path to a namespace, a TypeMap, or a list consisting paths to namespaces and TypeMaps', - 'default': None}, - returns="the namespaces loaded from the given file", rtype=tuple, - is_method=False) +@docval( + { + "name": "extensions", + "type": (str, TypeMap, list), + "doc": "a path to a namespace, a TypeMap, or a list consisting paths to namespaces and TypeMaps", + "default": None, + }, + returns="the namespaces loaded from the given file", + rtype=tuple, + is_method=False, +) def get_type_map(**kwargs): - ''' + """ Get a BuildManager to use for I/O using the given extensions. 
If no extensions are provided, return a BuildManager that uses the core namespace - ''' - extensions = getargs('extensions', kwargs) + """ + extensions = getargs("extensions", kwargs) type_map = None if extensions is None: type_map = deepcopy(__TYPE_MAP) @@ -145,7 +193,7 @@ def get_type_map(**kwargs): elif isinstance(ext, TypeMap): type_map.merge(ext) else: - msg = 'extensions must be a list of paths to namespace specs or a TypeMaps' + msg = "extensions must be a list of paths to namespace specs or a TypeMaps" raise ValueError(msg) elif isinstance(extensions, str): type_map.load_namespaces(extensions) @@ -154,29 +202,42 @@ def get_type_map(**kwargs): return type_map -@docval(*get_docval(get_type_map), - returns="a build manager with namespaces loaded from the given file", rtype=BuildManager, - is_method=False) +@docval( + *get_docval(get_type_map), + returns="a build manager with namespaces loaded from the given file", + rtype=BuildManager, + is_method=False, +) def get_manager(**kwargs): - ''' + """ Get a BuildManager to use for I/O using the given extensions. If no extensions are provided, return a BuildManager that uses the core namespace - ''' + """ type_map = get_type_map(**kwargs) return BuildManager(type_map) -@docval({'name': 'io', 'type': HDMFIO, - 'doc': 'the HDMFIO object to read from'}, - {'name': 'namespace', 'type': str, - 'doc': 'the namespace to validate against', 'default': CORE_NAMESPACE}, - {'name': 'experimental', 'type': bool, - 'doc': 'data type is an experimental data type', 'default': False}, - returns="errors in the file", rtype=list, - is_method=False) +@docval( + {"name": "io", "type": HDMFIO, "doc": "the HDMFIO object to read from"}, + { + "name": "namespace", + "type": str, + "doc": "the namespace to validate against", + "default": CORE_NAMESPACE, + }, + { + "name": "experimental", + "type": bool, + "doc": "data type is an experimental data type", + "default": False, + }, + returns="errors in the file", + rtype=list, + is_method=False, +) def validate(**kwargs): """Validate an file against a namespace""" - io, namespace, experimental = getargs('io', 'namespace', 'experimental', kwargs) + io, namespace, experimental = getargs("io", "namespace", "experimental", kwargs) if experimental: namespace = EXP_NAMESPACE builder = io.read_builder() @@ -189,47 +250,48 @@ def get_hdf5io(**kwargs): """ A convenience method for getting an HDF5IO object using an HDMF-common build manager if none is provided. """ - manager = getargs('manager', kwargs) + manager = getargs("manager", kwargs) if manager is None: - kwargs['manager'] = get_manager() + kwargs["manager"] = get_manager() return HDF5IO(**kwargs) # load the hdmf-common namespace __resources = __get_resources() -if os.path.exists(__resources['namespace_path']): +if os.path.exists(__resources["namespace_path"]): __TYPE_MAP = TypeMap(NamespaceCatalog()) - load_namespaces(__resources['namespace_path']) + load_namespaces(__resources["namespace_path"]) # import these so the TypeMap gets populated - from . import io as __io # noqa: E402 - - from . import table # noqa: E402 from . import alignedtable # noqa: E402 - from . import sparse # noqa: E402 - from . import resources # noqa: E402 from . import multi # noqa: E402 + from . import resources # noqa: E402 + from . import sparse # noqa: E402 + from . import table # noqa: E402 + from . import io as __io # noqa: E402 # register custom class generators from .io.table import DynamicTableGenerator + __TYPE_MAP.register_generator(DynamicTableGenerator) - from .. 
import Data, Container - __TYPE_MAP.register_container_type(CORE_NAMESPACE, 'Container', Container) - __TYPE_MAP.register_container_type(CORE_NAMESPACE, 'Data', Data) + from .. import Container, Data + + __TYPE_MAP.register_container_type(CORE_NAMESPACE, "Container", Container) + __TYPE_MAP.register_container_type(CORE_NAMESPACE, "Data", Data) else: raise RuntimeError("Unable to load a TypeMap - no namespace file found") -DynamicTable = get_class('DynamicTable', CORE_NAMESPACE) -VectorData = get_class('VectorData', CORE_NAMESPACE) -VectorIndex = get_class('VectorIndex', CORE_NAMESPACE) -ElementIdentifiers = get_class('ElementIdentifiers', CORE_NAMESPACE) -DynamicTableRegion = get_class('DynamicTableRegion', CORE_NAMESPACE) -EnumData = get_class('EnumData', EXP_NAMESPACE) -CSRMatrix = get_class('CSRMatrix', CORE_NAMESPACE) -ExternalResources = get_class('ExternalResources', EXP_NAMESPACE) -SimpleMultiContainer = get_class('SimpleMultiContainer', CORE_NAMESPACE) -AlignedDynamicTable = get_class('AlignedDynamicTable', CORE_NAMESPACE) +DynamicTable = get_class("DynamicTable", CORE_NAMESPACE) +VectorData = get_class("VectorData", CORE_NAMESPACE) +VectorIndex = get_class("VectorIndex", CORE_NAMESPACE) +ElementIdentifiers = get_class("ElementIdentifiers", CORE_NAMESPACE) +DynamicTableRegion = get_class("DynamicTableRegion", CORE_NAMESPACE) +EnumData = get_class("EnumData", EXP_NAMESPACE) +CSRMatrix = get_class("CSRMatrix", CORE_NAMESPACE) +ExternalResources = get_class("ExternalResources", EXP_NAMESPACE) +SimpleMultiContainer = get_class("SimpleMultiContainer", CORE_NAMESPACE) +AlignedDynamicTable = get_class("AlignedDynamicTable", CORE_NAMESPACE) diff --git a/src/hdmf/common/alignedtable.py b/src/hdmf/common/alignedtable.py index 2cc20bbdc..0dc862141 100644 --- a/src/hdmf/common/alignedtable.py +++ b/src/hdmf/common/alignedtable.py @@ -6,12 +6,12 @@ import numpy as np import pandas as pd +from ..utils import AllowPositional, docval, get_docval, getargs, popargs from . import register_class from .table import DynamicTable -from ..utils import docval, getargs, popargs, get_docval, AllowPositional -@register_class('AlignedDynamicTable') +@register_class("AlignedDynamicTable") class AlignedDynamicTable(DynamicTable): """ DynamicTable container that supports storing a collection of subtables. Each sub-table is a @@ -25,27 +25,42 @@ class AlignedDynamicTable(DynamicTable): columns of the main table (not including the category tables). To get the full list of column names, use the get_colnames() function instead. """ - __fields__ = ({'name': 'category_tables', 'child': True}, ) - - @docval(*get_docval(DynamicTable.__init__), - {'name': 'category_tables', 'type': list, - 'doc': 'List of DynamicTables to be added to the container. NOTE: Only regular ' - 'DynamicTables are allowed. Using AlignedDynamicTable as a category for ' - 'AlignedDynamicTable is currently not supported.', 'default': None}, - {'name': 'categories', 'type': 'array_data', - 'doc': 'List of names with the ordering of category tables', 'default': None}, - allow_positional=AllowPositional.WARNING) + + __fields__ = ({"name": "category_tables", "child": True},) + + @docval( + *get_docval(DynamicTable.__init__), + { + "name": "category_tables", + "type": list, + "doc": ( + "List of DynamicTables to be added to the container. NOTE: Only regular" + " DynamicTables are allowed. Using AlignedDynamicTable as a category" + " for AlignedDynamicTable is currently not supported." 
+ ), + "default": None, + }, + { + "name": "categories", + "type": "array_data", + "doc": "List of names with the ordering of category tables", + "default": None, + }, + allow_positional=AllowPositional.WARNING, + ) def __init__(self, **kwargs): # noqa: C901 - in_category_tables = popargs('category_tables', kwargs) - in_categories = popargs('categories', kwargs) + in_category_tables = popargs("category_tables", kwargs) + in_categories = popargs("categories", kwargs) if in_category_tables is not None: # Error check to make sure that all category_table are regular DynamicTable for i, v in enumerate(in_category_tables): if not isinstance(v, DynamicTable): raise ValueError("Category table with index %i is not a DynamicTable" % i) if isinstance(v, AlignedDynamicTable): - raise ValueError("Category table with index %i is an AlignedDynamicTable. " - "Nesting of AlignedDynamicTable is currently not supported." % i) + raise ValueError( + "Category table with index %i is an AlignedDynamicTable. " + "Nesting of AlignedDynamicTable is currently not supported." % i + ) # set in_categories from the in_category_tables if it is empty if in_categories is None and in_category_tables is not None: in_categories = [tab.name for tab in in_category_tables] @@ -55,8 +70,10 @@ def __init__(self, **kwargs): # noqa: C901 # at this point both in_categories and in_category_tables should either both be None or both be a list if in_categories is not None: if len(in_categories) != len(in_category_tables): - raise ValueError("%s category_tables given but %s categories specified" % - (len(in_category_tables), len(in_categories))) + raise ValueError( + "%s category_tables given but %s categories specified" + % (len(in_category_tables), len(in_categories)) + ) # Initialize the main dynamic table super().__init__(**kwargs) # Create and set all sub-categories @@ -88,13 +105,16 @@ def __init__(self, **kwargs): # noqa: C901 # name in the in_categories list. We, therefore, exclude this check from coverage testing # but we leave it in just as a backup trigger in case something unexpected happens if tabel_index < 0: # pragma: no cover - raise ValueError("DynamicTable %s listed in categories but does not appear in category_tables" % - table_name) # pragma: no cover + raise ValueError( + "DynamicTable %s listed in categories but does not appear in category_tables" % table_name + ) # pragma: no cover # Test that all category tables have the correct number of rows category = in_category_tables[tabel_index] if len(category) != len(self): - raise ValueError('Category DynamicTable %s does not align, it has %i rows expected %i' % - (category.name, len(category), len(self))) + raise ValueError( + "Category DynamicTable %s does not align, it has %i rows expected %i" + % (category.name, len(category), len(self)) + ) # Add the category table to our category_tables. dts[category.name] = category # Set the self.category_tables attribute, which will set the parent/child relationships for the category_tables @@ -127,7 +147,9 @@ def categories(self): """ return list(self.category_tables.keys()) - @docval({'name': 'category', 'type': DynamicTable, 'doc': 'Add a new DynamicTable category'},) + @docval( + {"name": "category", "type": DynamicTable, "doc": "Add a new DynamicTable category"}, + ) def add_category(self, **kwargs): """ Add a new DynamicTable to the AlignedDynamicTable to create a new category in the table. 
@@ -139,29 +161,32 @@ def add_category(self, **kwargs): :raises: ValueError is raised if the input table does not have the same number of rows as the main table. ValueError is raised if the table is an AlignedDynamicTable instead of regular DynamicTable. """ - category = getargs('category', kwargs) + category = getargs("category", kwargs) if len(category) != len(self): - raise ValueError('New category DynamicTable does not align, it has %i rows expected %i' % - (len(category), len(self))) + raise ValueError( + "New category DynamicTable does not align, it has %i rows expected %i" % (len(category), len(self)) + ) if category.name in self.category_tables: raise ValueError("Category %s already in the table" % category.name) if isinstance(category, AlignedDynamicTable): - raise ValueError("Category is an AlignedDynamicTable. Nesting of AlignedDynamicTable " - "is currently not supported.") + raise ValueError( + "Category is an AlignedDynamicTable. Nesting of AlignedDynamicTable is currently not supported." + ) self.category_tables[category.name] = category category.parent = self - @docval({'name': 'name', 'type': str, 'doc': 'Name of the category we want to retrieve', 'default': None}) + @docval({"name": "name", "type": str, "doc": "Name of the category we want to retrieve", "default": None}) def get_category(self, **kwargs): - name = popargs('name', kwargs) + name = popargs("name", kwargs) if name is None or (name not in self.category_tables and name == self.name): return self else: return self.category_tables[name] - @docval(*get_docval(DynamicTable.add_column), - {'name': 'category', 'type': str, 'doc': 'The category the column should be added to', - 'default': None}) + @docval( + *get_docval(DynamicTable.add_column), + {"name": "category", "type": str, "doc": "The category the column should be added to", "default": None}, + ) def add_column(self, **kwargs): """ Add a column to the table @@ -169,7 +194,7 @@ def add_column(self, **kwargs): :raises: KeyError if the category does not exist """ - category_name = popargs('category', kwargs) + category_name = popargs("category", kwargs) if category_name is None: # Add the column to our main table super().add_column(**kwargs) @@ -181,16 +206,22 @@ def add_column(self, **kwargs): raise KeyError("Category %s not in table" % category_name) category.add_column(**kwargs) - @docval({'name': 'data', 'type': dict, 'doc': 'the data to put in this row', 'default': None}, - {'name': 'id', 'type': int, 'doc': 'the ID for the row', 'default': None}, - {'name': 'enforce_unique_id', 'type': bool, 'doc': 'enforce that the id in the table must be unique', - 'default': False}, - allow_extra=True) + @docval( + {"name": "data", "type": dict, "doc": "the data to put in this row", "default": None}, + {"name": "id", "type": int, "doc": "the ID for the row", "default": None}, + { + "name": "enforce_unique_id", + "type": bool, + "doc": "enforce that the id in the table must be unique", + "default": False, + }, + allow_extra=True, + ) def add_row(self, **kwargs): """ We can either provide the row data as a single dict or by specifying a dict for each category """ - data, row_id, enforce_unique_id = popargs('data', 'id', 'enforce_unique_id', kwargs) + data, row_id, enforce_unique_id = popargs("data", "id", "enforce_unique_id", kwargs) data = data if data is not None else kwargs # extract the category data @@ -200,24 +231,36 @@ def add_row(self, **kwargs): missing_categories = set(self.categories) - set(list(category_data.keys())) if missing_categories: raise KeyError( - 
'\n'.join([ - 'row data keys do not match available categories', - 'missing {} category keys: {}'.format(len(missing_categories), missing_categories) - ]) + "\n".join( + [ + "row data keys do not match available categories", + "missing {} category keys: {}".format(len(missing_categories), missing_categories), + ] + ) ) # Add the data to our main dynamic table - data['id'] = row_id - data['enforce_unique_id'] = enforce_unique_id + data["id"] = row_id + data["enforce_unique_id"] = enforce_unique_id super().add_row(**data) # Add the data to all out dynamic table categories for category, values in category_data.items(): self.category_tables[category].add_row(**values) - @docval({'name': 'include_category_tables', 'type': bool, - 'doc': "Ignore sub-category tables and just look at the main table", 'default': False}, - {'name': 'ignore_category_ids', 'type': bool, - 'doc': "Ignore id columns of sub-category tables", 'default': False}) + @docval( + { + "name": "include_category_tables", + "type": bool, + "doc": "Ignore sub-category tables and just look at the main table", + "default": False, + }, + { + "name": "ignore_category_ids", + "type": bool, + "doc": "Ignore id columns of sub-category tables", + "default": False, + }, + ) def get_colnames(self, **kwargs): """Get the full list of names of columns for this table @@ -225,29 +268,41 @@ def get_colnames(self, **kwargs): that contains the column and the second string is the name of the column. If include_category_tables is False, then a list of column names is returned. """ - if not getargs('include_category_tables', kwargs): + if not getargs("include_category_tables", kwargs): return self.colnames else: - ignore_category_ids = getargs('ignore_category_ids', kwargs) + ignore_category_ids = getargs("ignore_category_ids", kwargs) columns = [(self.name, c) for c in self.colnames] for category in self.category_tables.values(): if not ignore_category_ids: - columns += [(category.name, 'id'), ] + columns += [ + (category.name, "id"), + ] columns += [(category.name, c) for c in category.colnames] return columns - @docval({'name': 'ignore_category_ids', 'type': bool, - 'doc': "Ignore id columns of sub-category tables", 'default': False}) + @docval( + { + "name": "ignore_category_ids", + "type": bool, + "doc": "Ignore id columns of sub-category tables", + "default": False, + } + ) def to_dataframe(self, **kwargs): """Convert the collection of tables to a single pandas DataFrame""" - dfs = [super().to_dataframe().reset_index(), ] - if getargs('ignore_category_ids', kwargs): + dfs = [ + super().to_dataframe().reset_index(), + ] + if getargs("ignore_category_ids", kwargs): dfs += [category.to_dataframe() for category in self.category_tables.values()] else: dfs += [category.to_dataframe().reset_index() for category in self.category_tables.values()] - names = [self.name, ] + list(self.category_tables.keys()) + names = [ + self.name, + ] + list(self.category_tables.keys()) res = pd.concat(dfs, axis=1, keys=names) - res.set_index((self.name, 'id'), drop=True, inplace=True) + res.set_index((self.name, "id"), drop=True, inplace=True) return res def __getitem__(self, item): @@ -307,11 +362,14 @@ def get(self, item, **kwargs): """ if isinstance(item, (int, list, np.ndarray, slice)): # get a single full row from all tables - dfs = ([super().get(item, **kwargs).reset_index(), ] + - [category[item].reset_index() for category in self.category_tables.values()]) - names = [self.name, ] + list(self.category_tables.keys()) + dfs = [ + super().get(item, 
**kwargs).reset_index(), + ] + [category[item].reset_index() for category in self.category_tables.values()] + names = [ + self.name, + ] + list(self.category_tables.keys()) res = pd.concat(dfs, axis=1, keys=names) - res.set_index((self.name, 'id'), drop=True, inplace=True) + res.set_index((self.name, "id"), drop=True, inplace=True) return res elif isinstance(item, str) or item is None: if item in self.colnames: @@ -352,20 +410,28 @@ def get(self, item, **kwargs): else: return self.get_category(item[0])[item[1]][item[2]] else: - raise ValueError("Expected tuple of length 2 of the form [category, column], [row, category], " - "[row, (category, column)] or a tuple of length 3 of the form " - "[category, column, row], [row, category, column]") - - @docval({'name': 'ignore_category_tables', 'type': bool, - 'doc': "Ignore the category tables and only check in the main table columns", 'default': False}, - allow_extra=False) + raise ValueError( + "Expected tuple of length 2 of the form [category, column], [row," + " category], [row, (category, column)] or a tuple of length 3 of" + " the form [category, column, row], [row, category, column]" + ) + + @docval( + { + "name": "ignore_category_tables", + "type": bool, + "doc": "Ignore the category tables and only check in the main table columns", + "default": False, + }, + allow_extra=False, + ) def has_foreign_columns(self, **kwargs): """ Does the table contain DynamicTableRegion columns :returns: True if the table or any of the category tables contains a DynamicTableRegion column, else False """ - ignore_category_tables = getargs('ignore_category_tables', kwargs) + ignore_category_tables = getargs("ignore_category_tables", kwargs) if super().has_foreign_columns(): return True if not ignore_category_tables: @@ -374,9 +440,15 @@ def has_foreign_columns(self, **kwargs): return True return False - @docval({'name': 'ignore_category_tables', 'type': bool, - 'doc': "Ignore the category tables and only check in the main table columns", 'default': False}, - allow_extra=False) + @docval( + { + "name": "ignore_category_tables", + "type": bool, + "doc": "Ignore the category tables and only check in the main table columns", + "default": False, + }, + allow_extra=False, + ) def get_foreign_columns(self, **kwargs): """ Determine the names of all columns that link to another DynamicTable, i.e., @@ -387,17 +459,23 @@ def get_foreign_columns(self, **kwargs): category table (or None if the column is in the main table) and the second string is the column name. 
""" - ignore_category_tables = getargs('ignore_category_tables', kwargs) + ignore_category_tables = getargs("ignore_category_tables", kwargs) col_names = [(None, col_name) for col_name in super().get_foreign_columns()] if not ignore_category_tables: for table in self.category_tables.values(): col_names += [(table.name, col_name) for col_name in table.get_foreign_columns()] return col_names - @docval(*get_docval(DynamicTable.get_linked_tables), - {'name': 'ignore_category_tables', 'type': bool, - 'doc': "Ignore the category tables and only check in the main table columns", 'default': False}, - allow_extra=False) + @docval( + *get_docval(DynamicTable.get_linked_tables), + { + "name": "ignore_category_tables", + "type": bool, + "doc": "Ignore the category tables and only check in the main table columns", + "default": False, + }, + allow_extra=False, + ) def get_linked_tables(self, **kwargs): """ Get a list of the full list of all tables that are being linked to directly or indirectly @@ -411,6 +489,6 @@ def get_linked_tables(self, **kwargs): * 'target_table' : The target DynamicTable; same as source_column.table. """ - ignore_category_tables = getargs('ignore_category_tables', kwargs) + ignore_category_tables = getargs("ignore_category_tables", kwargs) other_tables = None if ignore_category_tables else list(self.category_tables.values()) return super().get_linked_tables(other_tables=other_tables) diff --git a/src/hdmf/common/hierarchicaltable.py b/src/hdmf/common/hierarchicaltable.py index 8322d2d73..2bc80f764 100644 --- a/src/hdmf/common/hierarchicaltable.py +++ b/src/hdmf/common/hierarchicaltable.py @@ -2,18 +2,24 @@ Module providing additional functionality for dealing with hierarchically nested tables, i.e., tables containing DynamicTableRegion references. """ -import pandas as pd import numpy as np -from hdmf.common.table import DynamicTable, DynamicTableRegion, VectorIndex +import pandas as pd + from hdmf.common.alignedtable import AlignedDynamicTable +from hdmf.common.table import DynamicTable, DynamicTableRegion, VectorIndex from hdmf.utils import docval, getargs -@docval({'name': 'dynamic_table', 'type': DynamicTable, - 'doc': 'DynamicTable object to be converted to a hierarchical pandas.Dataframe'}, - returns="Hierarchical pandas.DataFrame with usually a pandas.MultiIndex on both the index and columns.", - rtype='pandas.DataFrame', - is_method=False) +@docval( + { + "name": "dynamic_table", + "type": DynamicTable, + "doc": "DynamicTable object to be converted to a hierarchical pandas.Dataframe", + }, + returns="Hierarchical pandas.DataFrame with usually a pandas.MultiIndex on both the index and columns.", + rtype="pandas.DataFrame", + is_method=False, +) def to_hierarchical_dataframe(dynamic_table): """ Create a hierarchical pandas.DataFrame that represents all data from a collection of linked DynamicTables. @@ -36,7 +42,7 @@ def to_hierarchical_dataframe(dynamic_table): # if table does not contain any DynamicTableRegion columns then we can just convert it to a dataframe if len(foreign_columns) == 0: return dynamic_table.to_dataframe() - hcol_name = foreign_columns[0] # We only denormalize the first foreign column for now + hcol_name = foreign_columns[0] # We only denormalize the first foreign column for now hcol = dynamic_table[hcol_name] # Either a VectorIndex pointing to a DynamicTableRegion or a DynamicTableRegion # Get the target DynamicTable that hcol is pointing to. If hcol is a VectorIndex then we first need # to get the target of it before we look up the table. 
@@ -56,13 +62,15 @@ def to_hierarchical_dataframe(dynamic_table): if isinstance(hcol, VectorIndex): rows = hcol.get(slice(None), index=False, df=True) else: - rows = [hcol[i:(i+1)] for i in range(len(hcol))] + rows = [hcol[i : (i + 1)] for i in range(len(hcol))] # Retrieve the columns we need to iterate over from our input table. For AlignedDynamicTable we need to # use the get_colnames function instead of the colnames property to ensure we get all columns not just # the columns from the main table - dynamic_table_colnames = (dynamic_table.get_colnames(include_category_tables=True, ignore_category_ids=False) - if isinstance(dynamic_table, AlignedDynamicTable) - else dynamic_table.colnames) + dynamic_table_colnames = ( + dynamic_table.get_colnames(include_category_tables=True, ignore_category_ids=False) + if isinstance(dynamic_table, AlignedDynamicTable) + else dynamic_table.colnames + ) # Case 1: Our DynamicTableRegion column points to a DynamicTable that itself does not contain # any DynamicTableRegion references (i.e., we have reached the end of our table hierarchy). @@ -79,9 +87,9 @@ def to_hierarchical_dataframe(dynamic_table): # Determine the multi-index tuple for our row, consisting of: i) id of the row in this # table, ii) all columns (except the hierarchical column we are flattening), and # iii) the index (i.e., id) from our target row - index_data = ([dynamic_table.id[row_index], ] + - [dynamic_table[row_index, colname] - for colname in dynamic_table_colnames if colname != hcol_name]) + index_data = [ + dynamic_table.id[row_index], + ] + [dynamic_table[row_index, colname] for colname in dynamic_table_colnames if colname != hcol_name] index.append(tuple(index_data)) # Determine the names for our index and columns of our output table @@ -89,16 +97,22 @@ def to_hierarchical_dataframe(dynamic_table): # NOTE: While for a regular DynamicTable the "colnames" property will give us the full list of column names, # for AlignedDynamicTable we need to use the get_colnames() function instead to make sure we include # the category table columns as well. - index_names = ([(dynamic_table.name, 'id')] + - [(dynamic_table.name, colname) - for colname in dynamic_table_colnames if colname != hcol_name]) + index_names = [(dynamic_table.name, "id")] + [ + (dynamic_table.name, colname) for colname in dynamic_table_colnames if colname != hcol_name + ] # Determine the name of our columns - hcol_iter_columns = (hcol_target.get_colnames(include_category_tables=True, ignore_category_ids=False) - if isinstance(hcol_target, AlignedDynamicTable) - else hcol_target.colnames) - columns = pd.MultiIndex.from_tuples([(hcol_target.name, 'id'), ] + - [(hcol_target.name, c) for c in hcol_iter_columns], - names=('source_table', 'label')) + hcol_iter_columns = ( + hcol_target.get_colnames(include_category_tables=True, ignore_category_ids=False) + if isinstance(hcol_target, AlignedDynamicTable) + else hcol_target.colnames + ) + columns = pd.MultiIndex.from_tuples( + [ + (hcol_target.name, "id"), + ] + + [(hcol_target.name, c) for c in hcol_iter_columns], + names=("source_table", "label"), + ) # Case 2: Our DynamicTableRegion columns points to another table with a DynamicTableRegion, i.e., # we need to recursively resolve more levels of the table hierarchy @@ -121,17 +135,25 @@ def to_hierarchical_dataframe(dynamic_table): # Determine the column data for our row. 
data.append(row_tuple_level3[1:]) # Determine the multi-index tuple for our row, - index_data = ([dynamic_table.id[row_index], ] + - [dynamic_table[row_index, colname] - for colname in dynamic_table_colnames if colname != hcol_name] + - list(row_tuple_level3[0])) + index_data = ( + [ + dynamic_table.id[row_index], + ] + + [ + dynamic_table[row_index, colname] + for colname in dynamic_table_colnames + if colname != hcol_name + ] + + list(row_tuple_level3[0]) + ) index.append(tuple(index_data)) # Determine the names for our index and columns of our output table # We need to do this even if our table was empty (i.e. even is len(rows)==0) - index_names = ([(dynamic_table.name, "id")] + - [(dynamic_table.name, colname) - for colname in dynamic_table_colnames if colname != hcol_name] + - hcol_hdf.index.names) + index_names = ( + [(dynamic_table.name, "id")] + + [(dynamic_table.name, colname) for colname in dynamic_table_colnames if colname != hcol_name] + + hcol_hdf.index.names + ) columns = hcol_hdf.columns # Check if the index contains any unhashable types. If a table contains a VectorIndex column @@ -187,20 +209,31 @@ def __flatten_column_name(col): if isinstance(v, tuple): temp += list(v) else: - temp += [v, ] + temp += [ + v, + ] re = temp return tuple(re) else: return col -@docval({'name': 'dataframe', 'type': pd.DataFrame, - 'doc': 'Pandas dataframe to update (usually generated by the to_hierarchical_dataframe function)'}, - {'name': 'inplace', 'type': 'bool', 'doc': 'Update the dataframe inplace or return a modified copy', - 'default': False}, - returns="pandas.DataFrame with the id columns removed", - rtype='pandas.DataFrame', - is_method=False) +@docval( + { + "name": "dataframe", + "type": pd.DataFrame, + "doc": "Pandas dataframe to update (usually generated by the to_hierarchical_dataframe function)", + }, + { + "name": "inplace", + "type": "bool", + "doc": "Update the dataframe inplace or return a modified copy", + "default": False, + }, + returns="pandas.DataFrame with the id columns removed", + rtype="pandas.DataFrame", + is_method=False, +) def drop_id_columns(**kwargs): """ Drop all columns named 'id' from the table. @@ -212,8 +245,8 @@ def drop_id_columns(**kwargs): :raises TypeError: In case that dataframe parameter is not a pandas.Dataframe. """ - dataframe, inplace = getargs('dataframe', 'inplace', kwargs) - col_name = 'id' + dataframe, inplace = getargs("dataframe", "inplace", kwargs) + col_name = "id" drop_labels = [] for col in dataframe.columns: if __get_col_name(col) == col_name: @@ -222,19 +255,33 @@ def drop_id_columns(**kwargs): return dataframe if inplace else re -@docval({'name': 'dataframe', 'type': pd.DataFrame, - 'doc': 'Pandas dataframe to update (usually generated by the to_hierarchical_dataframe function)'}, - {'name': 'max_levels', 'type': (int, np.integer), - 'doc': 'Maximum number of levels to use in the resulting column Index. NOTE: When ' - 'limiting the number of levels the function simply removes levels from the ' - 'beginning. As such, removing levels may result in columns with duplicate names.' 
- 'Value must be >0.', - 'default': None}, - {'name': 'inplace', 'type': 'bool', 'doc': 'Update the dataframe inplace or return a modified copy', - 'default': False}, - returns="pandas.DataFrame with a regular pandas.Index columns rather and a pandas.MultiIndex", - rtype='pandas.DataFrame', - is_method=False) +@docval( + { + "name": "dataframe", + "type": pd.DataFrame, + "doc": "Pandas dataframe to update (usually generated by the to_hierarchical_dataframe function)", + }, + { + "name": "max_levels", + "type": (int, np.integer), + "doc": ( + "Maximum number of levels to use in the resulting column Index. NOTE: When" + " limiting the number of levels the function simply removes levels from the" + " beginning. As such, removing levels may result in columns with duplicate" + " names.Value must be >0." + ), + "default": None, + }, + { + "name": "inplace", + "type": "bool", + "doc": "Update the dataframe inplace or return a modified copy", + "default": False, + }, + returns="pandas.DataFrame with a regular pandas.Index columns rather and a pandas.MultiIndex", + rtype="pandas.DataFrame", + is_method=False, +) def flatten_column_index(**kwargs): """ Flatten the column index of a pandas DataFrame. @@ -247,9 +294,9 @@ def flatten_column_index(**kwargs): :raises ValueError: In case the num_levels is not >0 :raises TypeError: In case that dataframe parameter is not a pandas.Dataframe. """ - dataframe, max_levels, inplace = getargs('dataframe', 'max_levels', 'inplace', kwargs) + dataframe, max_levels, inplace = getargs("dataframe", "max_levels", "inplace", kwargs) if max_levels is not None and max_levels <= 0: - raise ValueError('max_levels must be greater than 0') + raise ValueError("max_levels must be greater than 0") # Compute the new column names col_names = [__flatten_column_name(col) for col in dataframe.columns.values] # Apply the max_levels filter. Make sure to do this only for columns that are actually tuples diff --git a/src/hdmf/common/io/__init__.py b/src/hdmf/common/io/__init__.py index 27c13df27..ed72255f8 100644 --- a/src/hdmf/common/io/__init__.py +++ b/src/hdmf/common/io/__init__.py @@ -1,4 +1,4 @@ +from . import alignedtable from . import multi -from . import table from . import resources -from . import alignedtable +from . import table diff --git a/src/hdmf/common/io/alignedtable.py b/src/hdmf/common/io/alignedtable.py index 3ff7f8d3f..0006276f0 100644 --- a/src/hdmf/common/io/alignedtable.py +++ b/src/hdmf/common/io/alignedtable.py @@ -8,8 +8,9 @@ class AlignedDynamicTableMap(DynamicTableMap): """ Customize the mapping for AlignedDynamicTable """ + def __init__(self, spec): super().__init__(spec) # By default the DynamicTables contained as sub-categories in the AlignedDynamicTable are mapped to # the 'dynamic_tables' class attribute. This renames the attribute to 'category_tables' - self.map_spec('category_tables', spec.get_data_type('DynamicTable')) + self.map_spec("category_tables", spec.get_data_type("DynamicTable")) diff --git a/src/hdmf/common/io/multi.py b/src/hdmf/common/io/multi.py index c2493255d..e2c4d8db9 100644 --- a/src/hdmf/common/io/multi.py +++ b/src/hdmf/common/io/multi.py @@ -1,23 +1,21 @@ -from .. import register_map -from ..multi import SimpleMultiContainer from ...build import ObjectMapper from ...container import Container, Data +from .. 
import register_map +from ..multi import SimpleMultiContainer @register_map(SimpleMultiContainer) class SimpleMultiContainerMap(ObjectMapper): - - @ObjectMapper.object_attr('containers') + @ObjectMapper.object_attr("containers") def containers_attr(self, container, manager): return [c for c in container.containers.values() if isinstance(c, Container)] - @ObjectMapper.constructor_arg('containers') + @ObjectMapper.constructor_arg("containers") def containers_carg(self, builder, manager): - return [manager.construct(sub) for sub in builder.datasets.values() - if manager.is_sub_data_type(sub, 'Data')] + \ - [manager.construct(sub) for sub in builder.groups.values() - if manager.is_sub_data_type(sub, 'Container')] + return [ + manager.construct(sub) for sub in builder.datasets.values() if manager.is_sub_data_type(sub, "Data") + ] + [manager.construct(sub) for sub in builder.groups.values() if manager.is_sub_data_type(sub, "Container")] - @ObjectMapper.object_attr('datas') + @ObjectMapper.object_attr("datas") def datas_attr(self, container, manager): return [c for c in container.containers.values() if isinstance(c, Data)] diff --git a/src/hdmf/common/io/resources.py b/src/hdmf/common/io/resources.py index 6ecf7088a..3d4463fb6 100644 --- a/src/hdmf/common/io/resources.py +++ b/src/hdmf/common/io/resources.py @@ -1,16 +1,22 @@ -from .. import register_map -from ..resources import ExternalResources, KeyTable, FileTable, ObjectTable, ObjectKeyTable, EntityTable from ...build import ObjectMapper +from .. import register_map +from ..resources import ( + EntityTable, + ExternalResources, + FileTable, + KeyTable, + ObjectKeyTable, + ObjectTable, +) @register_map(ExternalResources) class ExternalResourcesMap(ObjectMapper): - def construct_helper(self, name, parent_builder, table_cls, manager): """Create a new instance of table_cls with data from parent_builder[name]. - The DatasetBuilder for name is associated with data_type Data and container class Data, - but users should use the more specific table_cls for these datasets. + The DatasetBuilder for name is associated with data_type Data and container class Data, + but users should use the more specific table_cls for these datasets. 
""" parent = manager._get_proxy_builder(parent_builder) builder = parent_builder[name] @@ -19,22 +25,22 @@ def construct_helper(self, name, parent_builder, table_cls, manager): kwargs = dict(name=builder.name, data=builder.data) return self.__new_container__(table_cls, src, parent, oid, **kwargs) - @ObjectMapper.constructor_arg('keys') + @ObjectMapper.constructor_arg("keys") def keys(self, builder, manager): - return self.construct_helper('keys', builder, KeyTable, manager) + return self.construct_helper("keys", builder, KeyTable, manager) - @ObjectMapper.constructor_arg('files') + @ObjectMapper.constructor_arg("files") def files(self, builder, manager): - return self.construct_helper('files', builder, FileTable, manager) + return self.construct_helper("files", builder, FileTable, manager) - @ObjectMapper.constructor_arg('entities') + @ObjectMapper.constructor_arg("entities") def entities(self, builder, manager): - return self.construct_helper('entities', builder, EntityTable, manager) + return self.construct_helper("entities", builder, EntityTable, manager) - @ObjectMapper.constructor_arg('objects') + @ObjectMapper.constructor_arg("objects") def objects(self, builder, manager): - return self.construct_helper('objects', builder, ObjectTable, manager) + return self.construct_helper("objects", builder, ObjectTable, manager) - @ObjectMapper.constructor_arg('object_keys') + @ObjectMapper.constructor_arg("object_keys") def object_keys(self, builder, manager): - return self.construct_helper('object_keys', builder, ObjectKeyTable, manager) + return self.construct_helper("object_keys", builder, ObjectKeyTable, manager) diff --git a/src/hdmf/common/io/table.py b/src/hdmf/common/io/table.py index 0cde4de9e..351acec3d 100644 --- a/src/hdmf/common/io/table.py +++ b/src/hdmf/common/io/table.py @@ -1,52 +1,54 @@ -from .. import register_map -from ..table import DynamicTable, VectorData, VectorIndex, DynamicTableRegion -from ...build import ObjectMapper, BuildManager, CustomClassGenerator +from ...build import BuildManager, CustomClassGenerator, ObjectMapper from ...spec import Spec -from ...utils import docval, getargs, popargs, AllowPositional +from ...utils import AllowPositional, docval, getargs, popargs +from .. 
import register_map +from ..table import DynamicTable, DynamicTableRegion, VectorData, VectorIndex @register_map(DynamicTable) class DynamicTableMap(ObjectMapper): - def __init__(self, spec): super().__init__(spec) - vector_data_spec = spec.get_data_type('VectorData') - self.map_spec('columns', vector_data_spec) + vector_data_spec = spec.get_data_type("VectorData") + self.map_spec("columns", vector_data_spec) - @ObjectMapper.object_attr('colnames') + @ObjectMapper.object_attr("colnames") def attr_columns(self, container, manager): if all(not col for col in container.columns): return tuple() return container.colnames - @docval({"name": "spec", "type": Spec, "doc": "the spec to get the attribute value for"}, - {"name": "container", "type": DynamicTable, "doc": "the container to get the attribute value from"}, - {"name": "manager", "type": BuildManager, "doc": "the BuildManager used for managing this build"}, - returns='the value of the attribute') + @docval( + {"name": "spec", "type": Spec, "doc": "the spec to get the attribute value for"}, + {"name": "container", "type": DynamicTable, "doc": "the container to get the attribute value from"}, + {"name": "manager", "type": BuildManager, "doc": "the BuildManager used for managing this build"}, + returns="the value of the attribute", + ) def get_attr_value(self, **kwargs): - ''' Get the value of the attribute corresponding to this spec from the given container ''' - spec, container, manager = getargs('spec', 'container', 'manager', kwargs) + """Get the value of the attribute corresponding to this spec from the given container""" + spec, container, manager = getargs("spec", "container", "manager", kwargs) attr_value = super().get_attr_value(spec, container, manager) if attr_value is None and spec.name in container: - if spec.data_type_inc == 'VectorData': + if spec.data_type_inc == "VectorData": attr_value = container[spec.name] if isinstance(attr_value, VectorIndex): attr_value = attr_value.target - elif spec.data_type_inc == 'DynamicTableRegion': + elif spec.data_type_inc == "DynamicTableRegion": attr_value = container[spec.name] if isinstance(attr_value, VectorIndex): attr_value = attr_value.target if attr_value.table is None: - msg = "empty or missing table for DynamicTableRegion '%s' in DynamicTable '%s'" % \ - (attr_value.name, container.name) + msg = "empty or missing table for DynamicTableRegion '%s' in DynamicTable '%s'" % ( + attr_value.name, + container.name, + ) raise ValueError(msg) - elif spec.data_type_inc == 'VectorIndex': + elif spec.data_type_inc == "VectorIndex": attr_value = container[spec.name] return attr_value class DynamicTableGenerator(CustomClassGenerator): - @classmethod def apply_generator_to_field(cls, field_spec, bases, type_map): """Return True if this is a DynamicTable and the field spec is a column.""" @@ -59,7 +61,16 @@ def apply_generator_to_field(cls, field_spec, bases, type_map): return isinstance(dtype, type) and issubclass(dtype, VectorData) @classmethod - def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_inherited_fields, type_map, spec): + def process_field_spec( + cls, + classdict, + docval_args, + parent_cls, + attr_name, + not_inherited_fields, + type_map, + spec, + ): """Add __columns__ to the classdict and update the docval args for the field spec with the given attribute name. :param classdict: The dict to update with __columns__. :param docval_args: The list of docval arguments. 
@@ -69,33 +80,33 @@ def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_i :param type_map: The type map to use. :param spec: The spec for the container class to generate. """ - if attr_name.endswith('_index'): # do not add index columns to __columns__ + if attr_name.endswith("_index"): # do not add index columns to __columns__ return field_spec = not_inherited_fields[attr_name] column_conf = dict( name=attr_name, - description=field_spec['doc'], - required=field_spec.required + description=field_spec["doc"], + required=field_spec.required, ) dtype = cls._get_type(field_spec, type_map) if issubclass(dtype, DynamicTableRegion): # the spec does not know which table this DTR points to # the user must specify the table attribute on the DTR after it is generated - column_conf['table'] = True + column_conf["table"] = True else: - column_conf['class'] = dtype + column_conf["class"] = dtype index_counter = 0 index_name = attr_name - while '{}_index'.format(index_name) in not_inherited_fields: # an index column exists for this column - index_name = '{}_index'.format(index_name) + while "{}_index".format(index_name) in not_inherited_fields: # an index column exists for this column + index_name = "{}_index".format(index_name) index_counter += 1 if index_counter == 1: - column_conf['index'] = True + column_conf["index"] = True elif index_counter > 1: - column_conf['index'] = index_counter + column_conf["index"] = index_counter - classdict.setdefault('__columns__', list()).append(column_conf) + classdict.setdefault("__columns__", list()).append(column_conf) # do not add DynamicTable columns to init docval @@ -108,33 +119,36 @@ def post_process(cls, classdict, bases, docval_args, spec): :param spec: The spec for the container class to generate. """ # convert classdict['__columns__'] from list to tuple if present - columns = classdict.get('__columns__') + columns = classdict.get("__columns__") if columns is not None: - classdict['__columns__'] = tuple(columns) + classdict["__columns__"] = tuple(columns) @classmethod def set_init(cls, classdict, bases, docval_args, not_inherited_fields, name): - if '__columns__' not in classdict: + if "__columns__" not in classdict: return - base_init = classdict.get('__init__') + base_init = classdict.get("__init__") if base_init is None: # pragma: no cover raise ValueError("Generated class dictionary is missing base __init__ method.") # add a specialized docval arg for __init__ for specifying targets for DTRs docval_args_local = docval_args.copy() target_tables_dvarg = dict( - name='target_tables', - doc=('dict mapping DynamicTableRegion column name to the table that the DTR points to. The column is ' - 'added to the table if it is not already present (i.e., when it is optional).'), + name="target_tables", + doc=( + "dict mapping DynamicTableRegion column name to the table that the DTR" + " points to. The column is added to the table if it is not already" + " present (i.e., when it is optional)." 
+ ), type=dict, - default=None + default=None, ) cls._add_to_docval_args(docval_args_local, target_tables_dvarg, err_if_present=True) @docval(*docval_args_local, allow_positional=AllowPositional.WARNING) def __init__(self, **kwargs): - target_tables = popargs('target_tables', kwargs) + target_tables = popargs("target_tables", kwargs) base_init(self, **kwargs) # set target attribute on DTR @@ -143,23 +157,27 @@ def __init__(self, **kwargs): if colname not in self: # column has not yet been added (it is optional) column_conf = None for conf in self.__columns__: - if conf['name'] == colname: + if conf["name"] == colname: column_conf = conf break if column_conf is None: - raise ValueError("'%s' is not the name of a predefined column of table %s." - % (colname, self)) - if not column_conf.get('table', False): - raise ValueError("Column '%s' must be a DynamicTableRegion to have a target table." - % colname) - self.add_column(name=column_conf['name'], - description=column_conf['description'], - index=column_conf.get('index', False), - table=True) + raise ValueError( + "'%s' is not the name of a predefined column of table %s." % (colname, self) + ) + if not column_conf.get("table", False): + raise ValueError( + "Column '%s' must be a DynamicTableRegion to have a target table." % colname + ) + self.add_column( + name=column_conf["name"], + description=column_conf["description"], + index=column_conf.get("index", False), + table=True, + ) if isinstance(self[colname], VectorIndex): col = self[colname].target else: col = self[colname] col.table = table - classdict['__init__'] = __init__ + classdict["__init__"] = __init__ diff --git a/src/hdmf/common/multi.py b/src/hdmf/common/multi.py index 598b3cc93..9661a3615 100644 --- a/src/hdmf/common/multi.py +++ b/src/hdmf/common/multi.py @@ -1,23 +1,32 @@ -from . import register_class from ..container import Container, Data, MultiContainerInterface -from ..utils import docval, popargs, AllowPositional +from ..utils import AllowPositional, docval, popargs +from . import register_class -@register_class('SimpleMultiContainer') +@register_class("SimpleMultiContainer") class SimpleMultiContainer(MultiContainerInterface): - __clsconf__ = { - 'attr': 'containers', - 'type': (Container, Data), - 'add': 'add_container', - 'get': 'get_container', + "attr": "containers", + "type": (Container, Data), + "add": "add_container", + "get": "get_container", } - @docval({'name': 'name', 'type': str, 'doc': 'the name of this container'}, - {'name': 'containers', 'type': (list, tuple), 'default': None, - 'doc': 'the Container or Data objects in this file'}, - allow_positional=AllowPositional.WARNING) + @docval( + { + "name": "name", + "type": str, + "doc": "the name of this container", + }, + { + "name": "containers", + "type": (list, tuple), + "default": None, + "doc": "the Container or Data objects in this file", + }, + allow_positional=AllowPositional.WARNING, + ) def __init__(self, **kwargs): - containers = popargs('containers', kwargs) + containers = popargs("containers", kwargs) super().__init__(**kwargs) self.containers = containers diff --git a/src/hdmf/common/resources.py b/src/hdmf/common/resources.py index 1f1e3b1c9..417fc79a7 100644 --- a/src/hdmf/common/resources.py +++ b/src/hdmf/common/resources.py @@ -1,12 +1,19 @@ -import pandas as pd +import os +from glob import glob + import numpy as np -from . import register_class, EXP_NAMESPACE -from . 
import get_type_map -from ..container import Table, Row, Container, AbstractContainer, ExternalResourcesManager -from ..utils import docval, popargs, AllowPositional +import pandas as pd + from ..build import TypeMap -from glob import glob -import os +from ..container import ( + AbstractContainer, + Container, + ExternalResourcesManager, + Row, + Table, +) +from ..utils import AllowPositional, docval, popargs +from . import EXP_NAMESPACE, get_type_map, register_class class KeyTable(Table): @@ -14,11 +21,14 @@ class KeyTable(Table): A table for storing keys used to reference external resources. """ - __defaultname__ = 'keys' + __defaultname__ = "keys" __columns__ = ( - {'name': 'key', 'type': str, - 'doc': 'The user key that maps to the resource term / registry symbol.'}, + { + "name": "key", + "type": str, + "doc": "The user key that maps to the resource term / registry symbol.", + }, ) @@ -35,16 +45,24 @@ class EntityTable(Table): A table for storing the external resources a key refers to. """ - __defaultname__ = 'entities' + __defaultname__ = "entities" __columns__ = ( - {'name': 'keys_idx', 'type': (int, Key), - 'doc': ('The index into the keys table for the user key that ' - 'maps to the resource term / registry symbol.')}, - {'name': 'entity_id', 'type': str, - 'doc': 'The unique ID for the resource term / registry symbol.'}, - {'name': 'entity_uri', 'type': str, - 'doc': 'The URI for the resource term / registry symbol.'}, + { + "name": "keys_idx", + "type": (int, Key), + "doc": "The index into the keys table for the user key that maps to the resource term / registry symbol.", + }, + { + "name": "entity_id", + "type": str, + "doc": "The unique ID for the resource term / registry symbol.", + }, + { + "name": "entity_uri", + "type": str, + "doc": "The URI for the resource term / registry symbol.", + }, ) @@ -61,12 +79,9 @@ class FileTable(Table): A table for storing file ids used in external resources. """ - __defaultname__ = 'files' + __defaultname__ = "files" - __columns__ = ( - {'name': 'file_object_id', 'type': str, - 'doc': 'The file id of the file that contains the object'}, - ) + __columns__ = ({"name": "file_object_id", "type": str, "doc": "The file id of the file that contains the object"},) class File(Row): @@ -82,21 +97,39 @@ class ObjectTable(Table): A table for storing objects (i.e. Containers) that contain keys that refer to external resources. """ - __defaultname__ = 'objects' + __defaultname__ = "objects" __columns__ = ( - {'name': 'files_idx', 'type': int, - 'doc': 'The row idx for the file_object_id in FileTable containing the object.'}, - {'name': 'object_id', 'type': str, - 'doc': 'The object ID for the Container/Data.'}, - {'name': 'object_type', 'type': str, - 'doc': 'The type of the object. This is also the parent in relative_path.'}, - {'name': 'relative_path', 'type': str, - 'doc': ('The relative_path of the attribute of the object that uses ', - 'an external resource reference key. Use an empty string if not applicable.')}, - {'name': 'field', 'type': str, - 'doc': ('The field of the compound data type using an external resource. ' - 'Use an empty string if not applicable.')} + { + "name": "files_idx", + "type": int, + "doc": "The row idx for the file_object_id in FileTable containing the object.", + }, + { + "name": "object_id", + "type": str, + "doc": "The object ID for the Container/Data.", + }, + { + "name": "object_type", + "type": str, + "doc": "The type of the object. 
This is also the parent in relative_path.", + }, + { + "name": "relative_path", + "type": str, + "doc": ( + "The relative_path of the attribute of the object that uses ", + "an external resource reference key. Use an empty string if not applicable.", + ), + }, + { + "name": "field", + "type": str, + "doc": ( + "The field of the compound data type using an external resource. Use an empty string if not applicable." + ), + }, ) @@ -113,13 +146,19 @@ class ObjectKeyTable(Table): A table for identifying which keys are used by which objects for referring to external resources. """ - __defaultname__ = 'object_keys' + __defaultname__ = "object_keys" __columns__ = ( - {'name': 'objects_idx', 'type': (int, Object), - 'doc': 'The index into the objects table for the Object that uses the Key.'}, - {'name': 'keys_idx', 'type': (int, Key), - 'doc': 'The index into the keys table that is used to make an external resource reference.'} + { + "name": "objects_idx", + "type": (int, Object), + "doc": "The index into the objects table for the Object that uses the Key.", + }, + { + "name": "keys_idx", + "type": (int, Key), + "doc": "The index into the keys table that is used to make an external resource reference.", + }, ) @@ -131,40 +170,66 @@ class ObjectKey(Row): __table__ = ObjectKeyTable -@register_class('ExternalResources', EXP_NAMESPACE) +@register_class("ExternalResources", EXP_NAMESPACE) class ExternalResources(Container): """A table for mapping user terms (i.e. keys) to resource entities.""" __fields__ = ( - {'name': 'keys', 'child': True}, - {'name': 'files', 'child': True}, - {'name': 'objects', 'child': True}, - {'name': 'object_keys', 'child': True}, - {'name': 'entities', 'child': True}, + {"name": "keys", "child": True}, + {"name": "files", "child": True}, + {"name": "objects", "child": True}, + {"name": "object_keys", "child": True}, + {"name": "entities", "child": True}, ) - @docval({'name': 'keys', 'type': KeyTable, 'default': None, - 'doc': 'The table storing user keys for referencing resources.'}, - {'name': 'files', 'type': FileTable, 'default': None, - 'doc': 'The table for storing file ids used in external resources.'}, - {'name': 'entities', 'type': EntityTable, 'default': None, - 'doc': 'The table storing entity information.'}, - {'name': 'objects', 'type': ObjectTable, 'default': None, - 'doc': 'The table storing object information.'}, - {'name': 'object_keys', 'type': ObjectKeyTable, 'default': None, - 'doc': 'The table storing object-resource relationships.'}, - {'name': 'type_map', 'type': TypeMap, 'default': None, - 'doc': 'The type map. If None is provided, the HDMF-common type map will be used.'}, - allow_positional=AllowPositional.WARNING) + @docval( + { + "name": "keys", + "type": KeyTable, + "default": None, + "doc": "The table storing user keys for referencing resources.", + }, + { + "name": "files", + "type": FileTable, + "default": None, + "doc": "The table for storing file ids used in external resources.", + }, + { + "name": "entities", + "type": EntityTable, + "default": None, + "doc": "The table storing entity information.", + }, + { + "name": "objects", + "type": ObjectTable, + "default": None, + "doc": "The table storing object information.", + }, + { + "name": "object_keys", + "type": ObjectKeyTable, + "default": None, + "doc": "The table storing object-resource relationships.", + }, + { + "name": "type_map", + "type": TypeMap, + "default": None, + "doc": "The type map. 
If None is provided, the HDMF-common type map will be used.", + }, + allow_positional=AllowPositional.WARNING, + ) def __init__(self, **kwargs): - name = 'external_resources' + name = "external_resources" super().__init__(name) - self.keys = kwargs['keys'] or KeyTable() - self.files = kwargs['files'] or FileTable() - self.entities = kwargs['entities'] or EntityTable() - self.objects = kwargs['objects'] or ObjectTable() - self.object_keys = kwargs['object_keys'] or ObjectKeyTable() - self.type_map = kwargs['type_map'] or get_type_map() + self.keys = kwargs["keys"] or KeyTable() + self.files = kwargs["files"] or FileTable() + self.entities = kwargs["entities"] or EntityTable() + self.objects = kwargs["objects"] or ObjectTable() + self.object_keys = kwargs["object_keys"] or ObjectKeyTable() + self.type_map = kwargs["type_map"] or get_type_map() @staticmethod def assert_external_resources_equal(left, right, check_dtype=True): @@ -184,41 +249,41 @@ def assert_external_resources_equal(left, right, check_dtype=True): """ errors = [] try: - pd.testing.assert_frame_equal(left.keys.to_dataframe(), - right.keys.to_dataframe(), - check_dtype=check_dtype) + pd.testing.assert_frame_equal( + left.keys.to_dataframe(), right.keys.to_dataframe(), check_dtype=check_dtype + ) # fmt: skip except AssertionError as e: errors.append(e) try: - pd.testing.assert_frame_equal(left.files.to_dataframe(), - right.files.to_dataframe(), - check_dtype=check_dtype) + pd.testing.assert_frame_equal( + left.files.to_dataframe(), right.files.to_dataframe(), check_dtype=check_dtype + ) except AssertionError as e: errors.append(e) try: - pd.testing.assert_frame_equal(left.objects.to_dataframe(), - right.objects.to_dataframe(), - check_dtype=check_dtype) + pd.testing.assert_frame_equal( + left.objects.to_dataframe(), right.objects.to_dataframe(), check_dtype=check_dtype + ) except AssertionError as e: errors.append(e) try: - pd.testing.assert_frame_equal(left.entities.to_dataframe(), - right.entities.to_dataframe(), - check_dtype=check_dtype) + pd.testing.assert_frame_equal( + left.entities.to_dataframe(), right.entities.to_dataframe(), check_dtype=check_dtype + ) except AssertionError as e: errors.append(e) try: - pd.testing.assert_frame_equal(left.object_keys.to_dataframe(), - right.object_keys.to_dataframe(), - check_dtype=check_dtype) + pd.testing.assert_frame_equal( + left.object_keys.to_dataframe(), right.object_keys.to_dataframe(), check_dtype=check_dtype + ) except AssertionError as e: errors.append(e) if len(errors) > 0: - msg = ''.join(str(e)+"\n\n" for e in errors) + msg = "".join(str(e) + "\n\n" for e in errors) raise AssertionError(msg) return True - @docval({'name': 'key_name', 'type': str, 'doc': 'The name of the key to be added.'}) + @docval({"name": "key_name", "type": str, "doc": "The name of the key to be added."}) def _add_key(self, **kwargs): """ Add a key to be used for making references to external resources. @@ -230,55 +295,78 @@ def _add_key(self, **kwargs): The returned Key objects must be managed by the caller so as to be appropriately passed to subsequent calls to methods for storing information about the different resources. """ - key = kwargs['key_name'] + key = kwargs["key_name"] return Key(key, table=self.keys) - @docval({'name': 'file_object_id', 'type': str, 'doc': 'The id of the file'}) + @docval({"name": "file_object_id", "type": str, "doc": "The id of the file"}) def _add_file(self, **kwargs): """ Add a file to be used for making references to external resources. 
This is optional when working in HDMF. """ - file_object_id = kwargs['file_object_id'] + file_object_id = kwargs["file_object_id"] return File(file_object_id, table=self.files) - @docval({'name': 'key', 'type': (str, Key), 'doc': 'The key to associate the entity with.'}, - {'name': 'entity_id', 'type': str, 'doc': 'The unique entity id.'}, - {'name': 'entity_uri', 'type': str, 'doc': 'The URI for the entity.'}) + @docval( + {"name": "key", "type": (str, Key), "doc": "The key to associate the entity with."}, + {"name": "entity_id", "type": str, "doc": "The unique entity id."}, + {"name": "entity_uri", "type": str, "doc": "The URI for the entity."}, + ) def _add_entity(self, **kwargs): """ Add an entity that will be referenced to using the given key. """ - key = kwargs['key'] - entity_id = kwargs['entity_id'] - entity_uri = kwargs['entity_uri'] + key = kwargs["key"] + entity_id = kwargs["entity_id"] + entity_uri = kwargs["entity_uri"] if not isinstance(key, Key): key = self._add_key(key) entity = Entity(key, entity_id, entity_uri, table=self.entities) return entity - @docval({'name': 'container', 'type': (str, AbstractContainer), - 'doc': 'The Container/Data object to add or the object id of the Container/Data object to add.'}, - {'name': 'files_idx', 'type': int, - 'doc': 'The file_object_id row idx.'}, - {'name': 'object_type', 'type': str, 'default': None, - 'doc': ('The type of the object. This is also the parent in relative_path. If omitted, ' - 'the name of the container class is used.')}, - {'name': 'relative_path', 'type': str, - 'doc': ('The relative_path of the attribute of the object that uses ', - 'an external resource reference key. Use an empty string if not applicable.')}, - {'name': 'field', 'type': str, 'default': '', - 'doc': ('The field of the compound data type using an external resource.')}) + @docval( + { + "name": "container", + "type": (str, AbstractContainer), + "doc": "The Container/Data object to add or the object id of the Container/Data object to add.", + }, + { + "name": "files_idx", + "type": int, + "doc": "The file_object_id row idx.", + }, + { + "name": "object_type", + "type": str, + "default": None, + "doc": ( + "The type of the object. This is also the parent in relative_path. If omitted, " + "the name of the container class is used." + ), + }, + { + "name": "relative_path", + "type": str, + "doc": ( + "The relative_path of the attribute of the object that uses ", + "an external resource reference key. Use an empty string if not applicable.", + ), + }, + { + "name": "field", + "type": str, + "default": "", + "doc": "The field of the compound data type using an external resource.", + }, + ) def _add_object(self, **kwargs): """ Add an object that references an external resource. 
""" - files_idx, container, object_type, relative_path, field = popargs('files_idx', - 'container', - 'object_type', - 'relative_path', - 'field', kwargs) + files_idx, container, object_type, relative_path, field = popargs( + "files_idx", "container", "object_type", "relative_path", "field", kwargs + ) if object_type is None: object_type = container.__class__.__name__ @@ -288,27 +376,53 @@ def _add_object(self, **kwargs): obj = Object(files_idx, container, object_type, relative_path, field, table=self.objects) return obj - @docval({'name': 'obj', 'type': (int, Object), 'doc': 'The Object that uses the Key.'}, - {'name': 'key', 'type': (int, Key), 'doc': 'The Key that the Object uses.'}) + @docval( + {"name": "obj", "type": (int, Object), "doc": "The Object that uses the Key."}, + {"name": "key", "type": (int, Key), "doc": "The Key that the Object uses."}, + ) def _add_object_key(self, **kwargs): """ Specify that an object (i.e. container and relative_path) uses a key to reference an external resource. """ - obj, key = popargs('obj', 'key', kwargs) + obj, key = popargs("obj", "key", kwargs) return ObjectKey(obj, key, table=self.object_keys) - @docval({'name': 'file', 'type': ExternalResourcesManager, 'doc': 'The file associated with the container.'}, - {'name': 'container', 'type': AbstractContainer, - 'doc': ('The Container/Data object that uses the key or ' - 'the object id for the Container/Data object that uses the key.')}, - {'name': 'relative_path', 'type': str, - 'doc': ('The relative_path of the attribute of the object that uses ', - 'an external resource reference key. Use an empty string if not applicable.'), - 'default': ''}, - {'name': 'field', 'type': str, 'default': '', - 'doc': ('The field of the compound data type using an external resource.')}, - {'name': 'create', 'type': bool, 'default': True}) + @docval( + { + "name": "file", + "type": ExternalResourcesManager, + "doc": "The file associated with the container.", + }, + { + "name": "container", + "type": AbstractContainer, + "doc": ( + "The Container/Data object that uses the key or " + "the object id for the Container/Data object that uses the key." + ), + }, + { + "name": "relative_path", + "type": str, + "doc": ( + "The relative_path of the attribute of the object that uses ", + "an external resource reference key. Use an empty string if not applicable.", + ), + "default": "", + }, + { + "name": "field", + "type": str, + "default": "", + "doc": "The field of the compound data type using an external resource.", + }, + { + "name": "create", + "type": bool, + "default": True, + }, + ) def _check_object_field(self, **kwargs): """ Check if a container, relative path, and field have been added. @@ -318,11 +432,11 @@ def _check_object_field(self, **kwargs): If the container, relative_path, and field have not been added, add them and return the corresponding Object. Otherwise, just return the Object. 
""" - file = kwargs['file'] - container = kwargs['container'] - relative_path = kwargs['relative_path'] - field = kwargs['field'] - create = kwargs['create'] + file = kwargs["file"] + container = kwargs["container"] + relative_path = kwargs["relative_path"] + field = kwargs["field"] + create = kwargs["create"] file_object_id = file.object_id files_idx = self.files.which(file_object_id=file_object_id) @@ -347,17 +461,25 @@ def _check_object_field(self, **kwargs): elif len(objecttable_idx) == 0 and not create: raise ValueError("Object not in Object Table.") else: - raise ValueError("Found multiple instances of the same object id, relative path, " - "and field in objects table.") + raise ValueError( + "Found multiple instances of the same object id, relative path, and field in objects table." + ) - @docval({'name': 'container', 'type': (str, AbstractContainer), - 'doc': ('The Container/Data object that uses the key or ' - 'the object id for the Container/Data object that uses the key.')}) + @docval( + { + "name": "container", + "type": (str, AbstractContainer), + "doc": ( + "The Container/Data object that uses the key or " + "the object id for the Container/Data object that uses the key." + ), + } + ) def _get_file_from_container(self, **kwargs): """ Method to retrieve a file associated with the container in the case a file is not provided. """ - container = kwargs['container'] + container = kwargs["container"] if isinstance(container, ExternalResourcesManager): file = container @@ -372,21 +494,46 @@ def _get_file_from_container(self, **kwargs): else: parent = parent.parent else: - msg = 'Could not find file. Add container to the file.' + msg = "Could not find file. Add container to the file." raise ValueError(msg) - @docval({'name': 'key_name', 'type': str, 'doc': 'The name of the Key to get.'}, - {'name': 'file', 'type': ExternalResourcesManager, 'doc': 'The file associated with the container.', - 'default': None}, - {'name': 'container', 'type': (str, AbstractContainer), 'default': None, - 'doc': ('The Container/Data object that uses the key or ' - 'the object id for the Container/Data object that uses the key.')}, - {'name': 'relative_path', 'type': str, - 'doc': ('The relative_path of the attribute of the object that uses ', - 'an external resource reference key. Use an empty string if not applicable.'), - 'default': ''}, - {'name': 'field', 'type': str, 'default': '', - 'doc': ('The field of the compound data type using an external resource.')}) + @docval( + { + "name": "key_name", + "type": str, + "doc": "The name of the Key to get.", + }, + { + "name": "file", + "type": ExternalResourcesManager, + "doc": "The file associated with the container.", + "default": None, + }, + { + "name": "container", + "type": (str, AbstractContainer), + "default": None, + "doc": ( + "The Container/Data object that uses the key or " + "the object id for the Container/Data object that uses the key." + ), + }, + { + "name": "relative_path", + "type": str, + "doc": ( + "The relative_path of the attribute of the object that uses ", + "an external resource reference key. Use an empty string if not applicable.", + ), + "default": "", + }, + { + "name": "field", + "type": str, + "default": "", + "doc": "The field of the compound data type using an external resource.", + }, + ) def get_key(self, **kwargs): """ Return a Key. 
@@ -394,22 +541,21 @@ def get_key(self, **kwargs): If container, relative_path, and field are provided, the Key that corresponds to the given name of the key for the given container, relative_path, and field is returned. """ - key_name, container, relative_path, field = popargs('key_name', 'container', 'relative_path', 'field', kwargs) + key_name, container, relative_path, field = popargs("key_name", "container", "relative_path", "field", kwargs) key_idx_matches = self.keys.which(key=key_name) - file = kwargs['file'] + file = kwargs["file"] if container is not None: if file is None: file = self._get_file_from_container(container=container) # if same key is used multiple times, determine # which instance based on the Container - object_field = self._check_object_field(file=file, - container=container, - relative_path=relative_path, - field=field) + object_field = self._check_object_field( + file=file, container=container, relative_path=relative_path, field=field + ) for row_idx in self.object_keys.which(objects_idx=object_field.idx): - key_idx = self.object_keys['keys_idx', row_idx] + key_idx = self.object_keys["keys_idx", row_idx] if key_idx in key_idx_matches: return self.keys.row[key_idx] msg = "No key found with that container." @@ -424,20 +570,51 @@ def get_key(self, **kwargs): else: return self.keys.row[key_idx_matches[0]] - @docval({'name': 'container', 'type': (str, AbstractContainer), 'default': None, - 'doc': ('The Container/Data object that uses the key or ' - 'the object_id for the Container/Data object that uses the key.')}, - {'name': 'attribute', 'type': str, - 'doc': 'The attribute of the container for the external reference.', 'default': None}, - {'name': 'field', 'type': str, 'default': '', - 'doc': ('The field of the compound data type using an external resource.')}, - {'name': 'key', 'type': (str, Key), 'default': None, - 'doc': 'The name of the key or the Key object from the KeyTable for the key to add a resource for.'}, - {'name': 'entity_id', 'type': str, 'doc': 'The identifier for the entity at the resource.'}, - {'name': 'entity_uri', 'type': str, 'doc': 'The URI for the identifier at the resource.'}, - {'name': 'file', 'type': ExternalResourcesManager, 'doc': 'The file associated with the container.', - 'default': None}, - ) + @docval( + { + "name": "container", + "type": (str, AbstractContainer), + "default": None, + "doc": ( + "The Container/Data object that uses the key or " + "the object_id for the Container/Data object that uses the key." + ), + }, + { + "name": "attribute", + "type": str, + "doc": "The attribute of the container for the external reference.", + "default": None, + }, + { + "name": "field", + "type": str, + "default": "", + "doc": "The field of the compound data type using an external resource.", + }, + { + "name": "key", + "type": (str, Key), + "default": None, + "doc": "The name of the key or the Key object from the KeyTable for the key to add a resource for.", + }, + { + "name": "entity_id", + "type": str, + "doc": "The identifier for the entity at the resource.", + }, + { + "name": "entity_uri", + "type": str, + "doc": "The URI for the identifier at the resource.", + }, + { + "name": "file", + "type": ExternalResourcesManager, + "doc": "The file associated with the container.", + "default": None, + }, + ) def add_ref(self, **kwargs): """ Add information about an external reference used in this file. @@ -447,31 +624,29 @@ def add_ref(self, **kwargs): field combination. This method does not support such functionality by default. 
""" ############################################################### - container = kwargs['container'] - attribute = kwargs['attribute'] - key = kwargs['key'] - field = kwargs['field'] - entity_id = kwargs['entity_id'] - entity_uri = kwargs['entity_uri'] - file = kwargs['file'] + container = kwargs["container"] + attribute = kwargs["attribute"] + key = kwargs["key"] + field = kwargs["field"] + entity_id = kwargs["entity_id"] + entity_uri = kwargs["entity_uri"] + file = kwargs["file"] if file is None: file = self._get_file_from_container(container=container) if attribute is None: # Trivial Case - relative_path = '' - object_field = self._check_object_field(file=file, - container=container, - relative_path=relative_path, - field=field) + relative_path = "" + object_field = self._check_object_field( + file=file, container=container, relative_path=relative_path, field=field + ) else: # DataType Attribute Case attribute_object = getattr(container, attribute) # returns attribute object if isinstance(attribute_object, AbstractContainer): - relative_path = '' - object_field = self._check_object_field(file=file, - container=attribute_object, - relative_path=relative_path, - field=field) + relative_path = "" + object_field = self._check_object_field( + file=file, container=attribute_object, relative_path=relative_path, field=field + ) else: # Non-DataType Attribute Case: obj_mapper = self.type_map.get_map(container) spec = obj_mapper.get_attr_spec(attr_name=attribute) @@ -484,30 +659,28 @@ def add_ref(self, **kwargs): parent = container # We need to get the path of the spec for relative_path absolute_path = spec.path - relative_path = absolute_path[absolute_path.find('/')+1:] - object_field = self._check_object_field(file=file, - container=parent, - relative_path=relative_path, - field=field) + relative_path = absolute_path[absolute_path.find("/") + 1 :] + object_field = self._check_object_field( + file=file, container=parent, relative_path=relative_path, field=field + ) else: - msg = 'Container not the nearest data_type' + msg = "Container not the nearest data_type" raise ValueError(msg) else: parent = container # container needs to be the parent absolute_path = spec.path - relative_path = absolute_path[absolute_path.find('/')+1:] + relative_path = absolute_path[absolute_path.find("/") + 1 :] # this regex removes everything prior to the container on the absolute_path - object_field = self._check_object_field(file=file, - container=parent, - relative_path=relative_path, - field=field) + object_field = self._check_object_field( + file=file, container=parent, relative_path=relative_path, field=field + ) if not isinstance(key, Key): key_idx_matches = self.keys.which(key=key) - # if same key is used multiple times, determine - # which instance based on the Container + # if same key is used multiple times, determine + # which instance based on the Container for row_idx in self.object_keys.which(objects_idx=object_field.idx): - key_idx = self.object_keys['keys_idx', row_idx] + key_idx = self.object_keys["keys_idx", row_idx] if key_idx in key_idx_matches: msg = "Use Key Object when referencing an existing (container, relative_path, key)" raise ValueError(msg) @@ -520,57 +693,99 @@ def add_ref(self, **kwargs): return key, entity - @docval({'name': 'object_type', 'type': str, - 'doc': 'The type of the object. This is also the parent in relative_path.'}, - {'name': 'relative_path', 'type': str, - 'doc': ('The relative_path of the attribute of the object that uses ', - 'an external resource reference key. 
Use an empty string if not applicable.'), - 'default': ''}, - {'name': 'field', 'type': str, 'default': '', - 'doc': ('The field of the compound data type using an external resource.')}, - {'name': 'all_instances', 'type': bool, 'default': False, - 'doc': ('The bool to return a dataframe with all instances of the object_type.', - 'If True, relative_path and field inputs will be ignored.')}) + @docval( + { + "name": "object_type", + "type": str, + "doc": "The type of the object. This is also the parent in relative_path.", + }, + { + "name": "relative_path", + "type": str, + "doc": ( + "The relative_path of the attribute of the object that uses ", + "an external resource reference key. Use an empty string if not applicable.", + ), + "default": "", + }, + { + "name": "field", + "type": str, + "default": "", + "doc": "The field of the compound data type using an external resource.", + }, + { + "name": "all_instances", + "type": bool, + "default": False, + "doc": ( + "The bool to return a dataframe with all instances of the object_type.", + "If True, relative_path and field inputs will be ignored.", + ), + }, + ) def get_object_type(self, **kwargs): """ Get all entities/resources associated with an object_type. """ - object_type = kwargs['object_type'] - relative_path = kwargs['relative_path'] - field = kwargs['field'] - all_instances = kwargs['all_instances'] + object_type = kwargs["object_type"] + relative_path = kwargs["relative_path"] + field = kwargs["field"] + all_instances = kwargs["all_instances"] df = self.to_dataframe() if all_instances: - df = df.loc[df['object_type'] == object_type] + df = df.loc[df["object_type"] == object_type] else: - df = df.loc[(df['object_type'] == object_type) - & (df['relative_path'] == relative_path) - & (df['field'] == field)] + df = df.loc[ + (df["object_type"] == object_type) & (df["relative_path"] == relative_path) & (df["field"] == field) + ] return df - @docval({'name': 'file', 'type': ExternalResourcesManager, 'doc': 'The file.', - 'default': None}, - {'name': 'container', 'type': (str, AbstractContainer), - 'doc': 'The Container/data object that is linked to resources/entities.'}, - {'name': 'attribute', 'type': str, - 'doc': 'The attribute of the container for the external reference.', 'default': None}, - {'name': 'relative_path', 'type': str, - 'doc': ('The relative_path of the attribute of the object that uses ', - 'an external resource reference key. Use an empty string if not applicable.'), - 'default': ''}, - {'name': 'field', 'type': str, 'default': '', - 'doc': ('The field of the compound data type using an external resource.')}) + @docval( + { + "name": "file", + "type": ExternalResourcesManager, + "doc": "The file.", + "default": None, + }, + { + "name": "container", + "type": (str, AbstractContainer), + "doc": "The Container/data object that is linked to resources/entities.", + }, + { + "name": "attribute", + "type": str, + "doc": "The attribute of the container for the external reference.", + "default": None, + }, + { + "name": "relative_path", + "type": str, + "doc": ( + "The relative_path of the attribute of the object that uses ", + "an external resource reference key. Use an empty string if not applicable.", + ), + "default": "", + }, + { + "name": "field", + "type": str, + "default": "", + "doc": "The field of the compound data type using an external resource.", + }, + ) def get_object_entities(self, **kwargs): """ Get all entities/resources associated with an object. 
""" - file = kwargs['file'] - container = kwargs['container'] - attribute = kwargs['attribute'] - relative_path = kwargs['relative_path'] - field = kwargs['field'] + file = kwargs["file"] + container = kwargs["container"] + attribute = kwargs["attribute"] + relative_path = kwargs["relative_path"] + field = kwargs["field"] if file is None: file = self._get_file_from_container(container=container) @@ -578,38 +793,41 @@ def get_object_entities(self, **kwargs): keys = [] entities = [] if attribute is None: - object_field = self._check_object_field(file=file, - container=container, - relative_path=relative_path, - field=field, - create=False) + object_field = self._check_object_field( + file=file, container=container, relative_path=relative_path, field=field, create=False + ) else: - object_field = self._check_object_field(file=file, - container=container[attribute], - relative_path=relative_path, - field=field, - create=False) + object_field = self._check_object_field( + file=file, container=container[attribute], relative_path=relative_path, field=field, create=False + ) # Find all keys associated with the object for row_idx in self.object_keys.which(objects_idx=object_field.idx): - keys.append(self.object_keys['keys_idx', row_idx]) + keys.append(self.object_keys["keys_idx", row_idx]) # Find all the entities/resources for each key. for key_idx in keys: entity_idx = self.entities.which(keys_idx=key_idx) entities.append(list(self.entities.__getitem__(entity_idx[0]))) - df = pd.DataFrame(entities, columns=['keys_idx', 'entity_id', 'entity_uri']) + df = pd.DataFrame(entities, columns=["keys_idx", "entity_id", "entity_uri"]) key_names = [] - for idx in df['keys_idx']: - key_id_val = self.keys.to_dataframe().iloc[int(idx)]['key'] + for idx in df["keys_idx"]: + key_id_val = self.keys.to_dataframe().iloc[int(idx)]["key"] key_names.append(key_id_val) - df['keys_idx'] = key_names - df = df.rename(columns={'keys_idx': 'key_names', 'entity_id': 'entity_id', 'entity_uri': 'entity_uri'}) + df["keys_idx"] = key_names + df = df.rename(columns={"keys_idx": "key_names", "entity_id": "entity_id", "entity_uri": "entity_uri"}) return df - @docval({'name': 'use_categories', 'type': bool, 'default': False, - 'doc': 'Use a multi-index on the columns to indicate which category each column belongs to.'}, - rtype=pd.DataFrame, returns='A DataFrame with all data merged into a flat, denormalized table.') + @docval( + { + "name": "use_categories", + "type": bool, + "default": False, + "doc": "Use a multi-index on the columns to indicate which category each column belongs to.", + }, + rtype=pd.DataFrame, + returns="A DataFrame with all data merged into a flat, denormalized table.", + ) def to_dataframe(self, **kwargs): """ Convert the data from the keys, resources, entities, objects, and object_keys tables @@ -620,153 +838,162 @@ def to_dataframe(self, **kwargs): Returns: :py:class:`~pandas.DataFrame` with all data merged into a single, flat, denormalized table. 
""" - use_categories = popargs('use_categories', kwargs) + use_categories = popargs("use_categories", kwargs) # Step 1: Combine the entities, keys, and files table entities_df = self.entities.to_dataframe() # Map the keys to the entities by 1) convert to dataframe, 2) select rows based on the keys_idx # from the entities table, expanding the dataframe to have the same number of rows as the # entities, and 3) reset the index to avoid duplicate values in the index, which causes errors when merging - keys_mapped_df = self.keys.to_dataframe().iloc[entities_df['keys_idx']].reset_index(drop=True) + keys_mapped_df = self.keys.to_dataframe().iloc[entities_df["keys_idx"]].reset_index(drop=True) # Map the resources to entities using the same strategy as for the keys # resources_mapped_df = self.resources.to_dataframe().iloc[entities_df['resources_idx']].reset_index(drop=True) # Merge the mapped keys and resources with the entities tables - entities_df = pd.concat(objs=[entities_df, keys_mapped_df], - axis=1, verify_integrity=False) + entities_df = pd.concat(objs=[entities_df, keys_mapped_df], axis=1, verify_integrity=False) # Add a column for the entity id (for consistency with the other tables and to facilitate query) - entities_df['entities_idx'] = entities_df.index + entities_df["entities_idx"] = entities_df.index # Step 2: Combine the the files, object_keys and objects tables object_keys_df = self.object_keys.to_dataframe() - objects_mapped_df = self.objects.to_dataframe().iloc[object_keys_df['objects_idx']].reset_index(drop=True) - object_keys_df = pd.concat(objs=[object_keys_df, objects_mapped_df], - axis=1, - verify_integrity=False) - files_df = self.files.to_dataframe().iloc[object_keys_df['files_idx']].reset_index(drop=True) - file_object_object_key_df = pd.concat(objs=[object_keys_df, files_df], - axis=1, - verify_integrity=False) + objects_mapped_df = self.objects.to_dataframe().iloc[object_keys_df["objects_idx"]].reset_index(drop=True) + object_keys_df = pd.concat(objs=[object_keys_df, objects_mapped_df], axis=1, verify_integrity=False) + files_df = self.files.to_dataframe().iloc[object_keys_df["files_idx"]].reset_index(drop=True) + file_object_object_key_df = pd.concat(objs=[object_keys_df, files_df], axis=1, verify_integrity=False) # Step 3: merge the combined entities_df and object_keys_df DataFrames result_df = pd.concat( # Create for each row in the objects_keys table a DataFrame with all corresponding data from all tables - objs=[pd.merge( + objs=[ + pd.merge( # Find all entities that correspond to the row i of the object_keys_table - entities_df[entities_df['keys_idx'] == object_keys_df['keys_idx'].iloc[i]].reset_index(drop=True), + entities_df[entities_df["keys_idx"] == object_keys_df["keys_idx"].iloc[i]].reset_index(drop=True), # Get a DataFrame for row i of the objects_keys_table - file_object_object_key_df.iloc[[i, ]], + file_object_object_key_df.iloc[ + [ + i, + ] + ], # Merge the entities and object_keys on the keys_idx column so that the values from the single # object_keys_table row are copied across all corresponding rows in the entities table - on='keys_idx') - for i in range(len(object_keys_df))], + on="keys_idx", + ) + for i in range(len(object_keys_df)) + ], # Concatenate the rows of the objs axis=0, - verify_integrity=False) + verify_integrity=False, + ) # Step 4: Clean up the index and sort columns by table type and name result_df.reset_index(inplace=True, drop=True) # ADD files file_id_col = [] - for idx in result_df['files_idx']: - file_id_val = 
self.files.to_dataframe().iloc[int(idx)]['file_object_id'] + for idx in result_df["files_idx"]: + file_id_val = self.files.to_dataframe().iloc[int(idx)]["file_object_id"] file_id_col.append(file_id_val) - result_df['file_object_id'] = file_id_col - column_labels = [('files', 'file_object_id'), - ('objects', 'objects_idx'), ('objects', 'object_id'), ('objects', 'files_idx'), - ('objects', 'object_type'), ('objects', 'relative_path'), ('objects', 'field'), - ('keys', 'keys_idx'), ('keys', 'key'), - ('entities', 'entities_idx'), ('entities', 'entity_id'), ('entities', 'entity_uri')] + result_df["file_object_id"] = file_id_col + column_labels = [ + ("files", "file_object_id"), + ("objects", "objects_idx"), + ("objects", "object_id"), + ("objects", "files_idx"), + ("objects", "object_type"), + ("objects", "relative_path"), + ("objects", "field"), + ("keys", "keys_idx"), + ("keys", "key"), + ("entities", "entities_idx"), + ("entities", "entity_id"), + ("entities", "entity_uri"), + ] # sort the columns based on our custom order - result_df = result_df.reindex(labels=[c[1] for c in column_labels], - axis=1) - result_df = result_df.astype({'keys_idx': 'uint32', - 'objects_idx': 'uint32', - 'files_idx': 'uint32', - 'entities_idx': 'uint32'}) + result_df = result_df.reindex(labels=[c[1] for c in column_labels], axis=1) + result_df = result_df.astype( + {"keys_idx": "uint32", "objects_idx": "uint32", "files_idx": "uint32", "entities_idx": "uint32"} + ) # Add the categories if requested if use_categories: result_df.columns = pd.MultiIndex.from_tuples(column_labels) # return the result return result_df - @docval({'name': 'path', 'type': str, 'doc': 'path of the folder tsv file to write'}) + @docval({"name": "path", "type": str, "doc": "path of the folder tsv file to write"}) def to_norm_tsv(self, **kwargs): """ Write the tables in ExternalResources to individual tsv files. 
""" - folder_path = kwargs['path'] + folder_path = kwargs["path"] for child in self.children: df = child.to_dataframe() - df.to_csv(folder_path+'/'+child.name+'.tsv', sep='\t', index=False) + df.to_csv(folder_path + "/" + child.name + ".tsv", sep="\t", index=False) @classmethod - @docval({'name': 'path', 'type': str, 'doc': 'path of the folder containing the tsv files to read'}, - returns="ExternalResources loaded from TSV", rtype="ExternalResources") + @docval( + {"name": "path", "type": str, "doc": "path of the folder containing the tsv files to read"}, + returns="ExternalResources loaded from TSV", + rtype="ExternalResources", + ) def from_norm_tsv(cls, **kwargs): - path = kwargs['path'] - tsv_paths = glob(path+'/*') + path = kwargs["path"] + tsv_paths = glob(path + "/*") for file in tsv_paths: file_name = os.path.basename(file) - if file_name == 'files.tsv': - files_df = pd.read_csv(file, sep='\t').replace(np.nan, '') - files = FileTable().from_dataframe(df=files_df, name='files', extra_ok=False) + if file_name == "files.tsv": + files_df = pd.read_csv(file, sep="\t").replace(np.nan, "") + files = FileTable().from_dataframe(df=files_df, name="files", extra_ok=False) continue - if file_name == 'keys.tsv': - keys_df = pd.read_csv(file, sep='\t').replace(np.nan, '') - keys = KeyTable().from_dataframe(df=keys_df, name='keys', extra_ok=False) + if file_name == "keys.tsv": + keys_df = pd.read_csv(file, sep="\t").replace(np.nan, "") + keys = KeyTable().from_dataframe(df=keys_df, name="keys", extra_ok=False) continue - if file_name == 'entities.tsv': - entities_df = pd.read_csv(file, sep='\t').replace(np.nan, '') - entities = EntityTable().from_dataframe(df=entities_df, name='entities', extra_ok=False) + if file_name == "entities.tsv": + entities_df = pd.read_csv(file, sep="\t").replace(np.nan, "") + entities = EntityTable().from_dataframe(df=entities_df, name="entities", extra_ok=False) continue - if file_name == 'objects.tsv': - objects_df = pd.read_csv(file, sep='\t').replace(np.nan, '') - objects = ObjectTable().from_dataframe(df=objects_df, name='objects', extra_ok=False) + if file_name == "objects.tsv": + objects_df = pd.read_csv(file, sep="\t").replace(np.nan, "") + objects = ObjectTable().from_dataframe(df=objects_df, name="objects", extra_ok=False) continue - if file_name == 'object_keys.tsv': - object_keys_df = pd.read_csv(file, sep='\t').replace(np.nan, '') - object_keys = ObjectKeyTable().from_dataframe(df=object_keys_df, name='object_keys', extra_ok=False) + if file_name == "object_keys.tsv": + object_keys_df = pd.read_csv(file, sep="\t").replace(np.nan, "") + object_keys = ObjectKeyTable().from_dataframe(df=object_keys_df, name="object_keys", extra_ok=False) continue # we need to check the idx columns in entities, objects, and object_keys - keys_idx = entities['keys_idx'] + keys_idx = entities["keys_idx"] for idx in keys_idx: if not int(idx) < keys.__len__(): msg = "Key Index out of range in EntityTable. Please check for alterations." raise ValueError(msg) - files_idx = objects['files_idx'] + files_idx = objects["files_idx"] for idx in files_idx: if not int(idx) < files.__len__(): msg = "File_ID Index out of range in ObjectTable. Please check for alterations." raise ValueError(msg) - object_idx = object_keys['objects_idx'] + object_idx = object_keys["objects_idx"] for idx in object_idx: if not int(idx) < objects.__len__(): msg = "Object Index out of range in ObjectKeyTable. Please check for alterations." 
raise ValueError(msg) - keys_idx = object_keys['keys_idx'] + keys_idx = object_keys["keys_idx"] for idx in keys_idx: if not int(idx) < keys.__len__(): msg = "Key Index out of range in ObjectKeyTable. Please check for alterations." raise ValueError(msg) - er = ExternalResources(files=files, - keys=keys, - entities=entities, - objects=objects, - object_keys=object_keys) + er = ExternalResources(files=files, keys=keys, entities=entities, objects=objects, object_keys=object_keys) return er - @docval({'name': 'path', 'type': str, 'doc': 'path of the tsv file to write'}) + @docval({"name": "path", "type": str, "doc": "path of the tsv file to write"}) def to_flat_tsv(self, **kwargs): """ Write ExternalResources as a single, flat table to TSV Internally, the function uses :py:meth:`pandas.DataFrame.to_csv`. Pandas can infer compression based on the filename, i.e., by changing the file extension to - '.gz', '.bz2', '.zip', '.xz', or '.zst' we can write compressed files. + '.gz', '.bz2', '.zip', '.xz', or '.zst', we can write compressed files. The TSV is formatted as follows: 1) line one indicates for each column the name of the table the column belongs to, 2) line two is the name of the column within the table, 3) subsequent lines are each a row in the flattened ExternalResources table. The first column is the @@ -775,13 +1002,16 @@ def to_flat_tsv(self, **kwargs): See also :py:meth:`~hdmf.common.resources.ExternalResources.from_tsv` """ # noqa: E501 - path = popargs('path', kwargs) + path = popargs("path", kwargs) df = self.to_dataframe(use_categories=True) - df.to_csv(path, sep='\t') + df.to_csv(path, sep="\t") @classmethod - @docval({'name': 'path', 'type': str, 'doc': 'path of the tsv file to read'}, - returns="ExternalResources loaded from TSV", rtype="ExternalResources") + @docval( + {"name": "path", "type": str, "doc": "path of the tsv file to read"}, + returns="ExternalResources loaded from TSV", + rtype="ExternalResources", + ) def from_flat_tsv(cls, **kwargs): """ Read ExternalResources from a flat tsv file @@ -806,6 +1036,7 @@ def from_flat_tsv(cls, **kwargs): the TSV without using the :py:meth:`~hdmf.common.resources.ExternalResources` class should be done with great care! 
""" + def check_idx(idx_arr, name): """Check that indices are consecutively numbered without missing values""" idx_diff = np.diff(idx_arr) @@ -814,68 +1045,68 @@ def check_idx(idx_arr, name): msg = "Missing %s entries %s" % (name, str(missing_idx)) raise ValueError(msg) - path = popargs('path', kwargs) - df = pd.read_csv(path, header=[0, 1], sep='\t').replace(np.nan, '') + path = popargs("path", kwargs) + df = pd.read_csv(path, header=[0, 1], sep="\t").replace(np.nan, "") # Construct the ExternalResources er = ExternalResources() # Retrieve all the Files - files_idx, files_rows = np.unique(df[('objects', 'files_idx')], return_index=True) + files_idx, files_rows = np.unique(df[("objects", "files_idx")], return_index=True) file_order = np.argsort(files_idx) files_idx = files_idx[file_order] files_rows = files_rows[file_order] # Check that files are consecutively numbered - check_idx(idx_arr=files_idx, name='files_idx') - files = df[('files', 'file_object_id')].iloc[files_rows] + check_idx(idx_arr=files_idx, name="files_idx") + files = df[("files", "file_object_id")].iloc[files_rows] for file in zip(files): er._add_file(file_object_id=file[0]) # Retrieve all the objects - ob_idx, ob_rows = np.unique(df[('objects', 'objects_idx')], return_index=True) + ob_idx, ob_rows = np.unique(df[("objects", "objects_idx")], return_index=True) # Sort objects based on their index ob_order = np.argsort(ob_idx) ob_idx = ob_idx[ob_order] ob_rows = ob_rows[ob_order] # Check that objects are consecutively numbered - check_idx(idx_arr=ob_idx, name='objects_idx') + check_idx(idx_arr=ob_idx, name="objects_idx") # Add the objects to the Object table - ob_files = df[('objects', 'files_idx')].iloc[ob_rows] - ob_ids = df[('objects', 'object_id')].iloc[ob_rows] - ob_types = df[('objects', 'object_type')].iloc[ob_rows] - ob_relpaths = df[('objects', 'relative_path')].iloc[ob_rows] - ob_fields = df[('objects', 'field')].iloc[ob_rows] + ob_files = df[("objects", "files_idx")].iloc[ob_rows] + ob_ids = df[("objects", "object_id")].iloc[ob_rows] + ob_types = df[("objects", "object_type")].iloc[ob_rows] + ob_relpaths = df[("objects", "relative_path")].iloc[ob_rows] + ob_fields = df[("objects", "field")].iloc[ob_rows] for ob in zip(ob_files, ob_ids, ob_types, ob_relpaths, ob_fields): er._add_object(files_idx=ob[0], container=ob[1], object_type=ob[2], relative_path=ob[3], field=ob[4]) # Retrieve all keys - keys_idx, keys_rows = np.unique(df[('keys', 'keys_idx')], return_index=True) + keys_idx, keys_rows = np.unique(df[("keys", "keys_idx")], return_index=True) # Sort keys based on their index keys_order = np.argsort(keys_idx) keys_idx = keys_idx[keys_order] keys_rows = keys_rows[keys_order] # Check that keys are consecutively numbered - check_idx(idx_arr=keys_idx, name='keys_idx') + check_idx(idx_arr=keys_idx, name="keys_idx") # Add the keys to the Keys table - keys_key = df[('keys', 'key')].iloc[keys_rows] + keys_key = df[("keys", "key")].iloc[keys_rows] all_added_keys = [er._add_key(k) for k in keys_key] # Add all the object keys to the ObjectKeys table. A single key may be assigned to multiple # objects. As such it is not sufficient to iterate over the unique ob_rows with the unique # objects, but we need to find all unique (objects_idx, keys_idx) combinations. 
- ob_keys_idx = np.unique(df[[('objects', 'objects_idx'), ('keys', 'keys_idx')]], axis=0) + ob_keys_idx = np.unique(df[[("objects", "objects_idx"), ("keys", "keys_idx")]], axis=0) for obk in ob_keys_idx: er._add_object_key(obj=obk[0], key=obk[1]) # Retrieve all entities - entities_idx, entities_rows = np.unique(df[('entities', 'entities_idx')], return_index=True) + entities_idx, entities_rows = np.unique(df[("entities", "entities_idx")], return_index=True) # Sort entities based on their index entities_order = np.argsort(entities_idx) entities_idx = entities_idx[entities_order] entities_rows = entities_rows[entities_order] # Check that entities are consecutively numbered - check_idx(idx_arr=entities_idx, name='entities_idx') + check_idx(idx_arr=entities_idx, name="entities_idx") # Add the entities to the Resources table - entities_id = df[('entities', 'entity_id')].iloc[entities_rows] - entities_uri = df[('entities', 'entity_uri')].iloc[entities_rows] - entities_keys = np.array(all_added_keys)[df[('keys', 'keys_idx')].iloc[entities_rows]] + entities_id = df[("entities", "entity_id")].iloc[entities_rows] + entities_uri = df[("entities", "entity_uri")].iloc[entities_rows] + entities_keys = np.array(all_added_keys)[df[("keys", "keys_idx")].iloc[entities_rows]] for e in zip(entities_keys, entities_id, entities_uri): er._add_entity(key=e[0], entity_id=e[1], entity_uri=e[2]) # Return the reconstructed ExternalResources diff --git a/src/hdmf/common/sparse.py b/src/hdmf/common/sparse.py index db38d12e8..de0a147a1 100644 --- a/src/hdmf/common/sparse.py +++ b/src/hdmf/common/sparse.py @@ -1,22 +1,49 @@ import scipy.sparse as sps -from . import register_class + from ..container import Container -from ..utils import docval, popargs, to_uint_array, get_data_shape, AllowPositional +from ..utils import AllowPositional, docval, get_data_shape, popargs, to_uint_array +from . import register_class -@register_class('CSRMatrix') +@register_class("CSRMatrix") class CSRMatrix(Container): - - @docval({'name': 'data', 'type': (sps.csr_matrix, 'array_data'), - 'doc': 'the data to use for this CSRMatrix or CSR data array.' 
- 'If passing CSR data array, *indices*, *indptr*, and *shape* must also be provided'}, - {'name': 'indices', 'type': 'array_data', 'doc': 'CSR index array', 'default': None}, - {'name': 'indptr', 'type': 'array_data', 'doc': 'CSR index pointer array', 'default': None}, - {'name': 'shape', 'type': 'array_data', 'doc': 'the shape of the matrix', 'default': None}, - {'name': 'name', 'type': str, 'doc': 'the name to use for this when storing', 'default': 'csr_matrix'}, - allow_positional=AllowPositional.WARNING) + @docval( + { + "name": "data", + "type": (sps.csr_matrix, "array_data"), + "doc": ( + "the data to use for this CSRMatrix or CSR data array.If passing CSR" + " data array, *indices*, *indptr*, and *shape* must also be provided" + ), + }, + { + "name": "indices", + "type": "array_data", + "doc": "CSR index array", + "default": None, + }, + { + "name": "indptr", + "type": "array_data", + "doc": "CSR index pointer array", + "default": None, + }, + { + "name": "shape", + "type": "array_data", + "doc": "the shape of the matrix", + "default": None, + }, + { + "name": "name", + "type": str, + "doc": "the name to use for this when storing", + "default": "csr_matrix", + }, + allow_positional=AllowPositional.WARNING, + ) def __init__(self, **kwargs): - data, indices, indptr, shape = popargs('data', 'indices', 'indptr', 'shape', kwargs) + data, indices, indptr, shape = popargs("data", "indices", "indptr", "shape", kwargs) super().__init__(**kwargs) if not isinstance(data, sps.csr_matrix): temp_shape = get_data_shape(data) @@ -26,9 +53,9 @@ def __init__(self, **kwargs): elif temp_ndim == 1: if any(_ is None for _ in (indptr, indices, shape)): raise ValueError("Must specify 'indptr', 'indices', and 'shape' arguments when passing data array.") - indptr = self.__check_arr(indptr, 'indptr') - indices = self.__check_arr(indices, 'indices') - shape = self.__check_arr(shape, 'shape') + indptr = self.__check_arr(indptr, "indptr") + indices = self.__check_arr(indices, "indices") + shape = self.__check_arr(shape, "shape") if len(shape) != 2: raise ValueError("'shape' argument must specify two and only two dimensions.") data = sps.csr_matrix((data, indices, indptr), shape=shape) @@ -49,7 +76,7 @@ def __check_arr(ar, arg): def __getattr__(self, val): # NOTE: this provides access to self.data, self.indices, self.indptr, self.shape attr = getattr(self.__data, val) - if val in ('indices', 'indptr', 'shape'): # needed because sps.csr_matrix may contain int arrays for these + if val in ("indices", "indptr", "shape"): # needed because sps.csr_matrix may contain int arrays for these attr = to_uint_array(attr) return attr diff --git a/src/hdmf/common/table.py b/src/hdmf/common/table.py index 9dd1ca267..1759826b1 100644 --- a/src/hdmf/common/table.py +++ b/src/hdmf/common/table.py @@ -3,6 +3,7 @@ the storage and use of dynamic data tables as part of the hdmf-common schema """ +import itertools import re from collections import OrderedDict from typing import NamedTuple, Union @@ -10,15 +11,14 @@ import numpy as np import pandas as pd -import itertools -from . import register_class, EXP_NAMESPACE from ..container import Container, Data -from ..data_utils import DataIO, AbstractDataChunkIterator -from ..utils import docval, getargs, ExtenderMeta, popargs, pystr, AllowPositional +from ..data_utils import AbstractDataChunkIterator, DataIO +from ..utils import AllowPositional, ExtenderMeta, docval, getargs, popargs, pystr +from . 
import EXP_NAMESPACE, register_class -@register_class('VectorData') +@register_class("VectorData") class VectorData(Data): """ A n-dimensional dataset representing a column of a DynamicTable. @@ -34,20 +34,34 @@ class VectorData(Data): __fields__ = ("description",) - @docval({'name': 'name', 'type': str, 'doc': 'the name of this VectorData'}, - {'name': 'description', 'type': str, 'doc': 'a description for this column'}, - {'name': 'data', 'type': ('array_data', 'data'), - 'doc': 'a dataset where the first dimension is a concatenation of multiple vectors', 'default': list()}, - allow_positional=AllowPositional.WARNING) + @docval( + { + "name": "name", + "type": str, + "doc": "the name of this VectorData", + }, + { + "name": "description", + "type": str, + "doc": "a description for this column", + }, + { + "name": "data", + "type": ("array_data", "data"), + "doc": "a dataset where the first dimension is a concatenation of multiple vectors", + "default": list(), + }, + allow_positional=AllowPositional.WARNING, + ) def __init__(self, **kwargs): - description = popargs('description', kwargs) + description = popargs("description", kwargs) super().__init__(**kwargs) self.description = description - @docval({'name': 'val', 'type': None, 'doc': 'the value to add to this column'}) + @docval({"name": "val", "type": None, "doc": "the value to add to this column"}) def add_row(self, **kwargs): """Append a data value to this VectorData column""" - val = getargs('val', kwargs) + val = getargs("val", kwargs) self.append(val) def get(self, key, **kwargs): @@ -79,7 +93,7 @@ def extend(self, ar, **kwargs): self.add_row(i, **kwargs) -@register_class('VectorIndex') +@register_class("VectorIndex") class VectorIndex(VectorData): """ When paired with a VectorData, this allows for storing arrays of varying @@ -90,15 +104,27 @@ class VectorIndex(VectorData): __fields__ = ("target",) - @docval({'name': 'name', 'type': str, 'doc': 'the name of this VectorIndex'}, - {'name': 'data', 'type': ('array_data', 'data'), - 'doc': 'a 1D dataset containing indexes that apply to VectorData object'}, - {'name': 'target', 'type': VectorData, - 'doc': 'the target dataset that this index applies to'}, - allow_positional=AllowPositional.WARNING) + @docval( + { + "name": "name", + "type": str, + "doc": "the name of this VectorIndex", + }, + { + "name": "data", + "type": ("array_data", "data"), + "doc": "a 1D dataset containing indexes that apply to VectorData object", + }, + { + "name": "target", + "type": VectorData, + "doc": "the target dataset that this index applies to", + }, + allow_positional=AllowPositional.WARNING, + ) def __init__(self, **kwargs): - target = popargs('target', kwargs) - kwargs['description'] = "Index for VectorData '%s'" % target.name + target = popargs("target", kwargs) + kwargs["description"] = "Index for VectorData '%s'" % target.name super().__init__(**kwargs) self.target = target self.__uint = np.uint8 @@ -130,13 +156,15 @@ def __check_precision(self, idx): """ if idx > self.__maxval: while idx > self.__maxval: - nbits = (np.log2(self.__maxval + 1) * 2) # 8->16, 16->32, 32->64 + nbits = np.log2(self.__maxval + 1) * 2 # 8->16, 16->32, 32->64 if nbits == 128: # pragma: no cover - msg = ('Cannot store more than 18446744073709551615 elements in a VectorData. Largest dtype ' - 'allowed for VectorIndex is uint64.') + msg = ( + "Cannot store more than 18446744073709551615 elements in a" + " VectorData. Largest dtype allowed for VectorIndex is uint64." 
+ ) raise ValueError(msg) - self.__maxval = 2 ** nbits - 1 - self.__uint = np.dtype('uint%d' % nbits).type + self.__maxval = 2**nbits - 1 + self.__uint = np.dtype("uint%d" % nbits).type self.__adjust_precision(self.__uint) return self.__uint(idx) @@ -151,7 +179,10 @@ def __adjust_precision(self, uint): # use self._Data__data to work around restriction on resetting self.data self._Data__data = self.data.astype(uint) else: - raise ValueError("cannot adjust precision of type %s to %s", (type(self.data), uint)) + raise ValueError( + "cannot adjust precision of type %s to %s", + (type(self.data), uint), + ) def add_row(self, arg, **kwargs): """ @@ -203,28 +234,46 @@ def get(self, arg, **kwargs): return ret -@register_class('ElementIdentifiers') +@register_class("ElementIdentifiers") class ElementIdentifiers(Data): """ Data container with a list of unique identifiers for values within a dataset, e.g. rows of a DynamicTable. """ - @docval({'name': 'name', 'type': str, 'doc': 'the name of this ElementIdentifiers'}, - {'name': 'data', 'type': ('array_data', 'data'), 'doc': 'a 1D dataset containing identifiers', - 'default': list()}, - allow_positional=AllowPositional.WARNING) + @docval( + { + "name": "name", + "type": str, + "doc": "the name of this ElementIdentifiers", + }, + { + "name": "data", + "type": ("array_data", "data"), + "doc": "a 1D dataset containing identifiers", + "default": list(), + }, + allow_positional=AllowPositional.WARNING, + ) def __init__(self, **kwargs): super().__init__(**kwargs) - @docval({'name': 'other', 'type': (Data, np.ndarray, list, tuple, int), - 'doc': 'List of ids to search for in this ElementIdentifer object'}, - rtype=np.ndarray, - returns='Array with the list of indices where the elements in the list where found.' - 'Note, the elements in the returned list are ordered in increasing index' - 'of the found elements, rather than in the order in which the elements' - 'where given for the search. Also the length of the result may be different from the length' - 'of the input array. E.g., if our ids are [1,2,3] and we are search for [3,1,5] the ' - 'result would be [0,2] and NOT [2,0,None]') + @docval( + { + "name": "other", + "type": (Data, np.ndarray, list, tuple, int), + "doc": "List of ids to search for in this ElementIdentifer object", + }, + rtype=np.ndarray, + returns=( + "Array with the list of indices where the elements in the list where" + " found.Note, the elements in the returned list are ordered in increasing" + " indexof the found elements, rather than in the order in which the" + " elementswhere given for the search. Also the length of the result may be" + " different from the lengthof the input array. E.g., if our ids are [1,2,3]" + " and we are search for [3,1,5] the result would be [0,2] and NOT" + " [2,0,None]" + ), + ) def __eq__(self, other): """ Given a list of ids return the indices in the ElementIdentifiers array where the indices are found. @@ -237,7 +286,7 @@ def __eq__(self, other): return np.in1d(self.data, search_ids).nonzero()[0] -@register_class('DynamicTable') +@register_class("DynamicTable") class DynamicTable(Container): r""" A column-based table. Columns are defined by the argument *columns*. 
This argument @@ -257,10 +306,10 @@ class DynamicTable(Container): """ __fields__ = ( - {'name': 'id', 'child': True}, - {'name': 'columns', 'child': True}, - 'colnames', - 'description' + {"name": "id", "child": True}, + {"name": "columns", "child": True}, + "colnames", + "description", ) __columns__ = tuple() @@ -277,23 +326,49 @@ def __gather_columns(cls, name, bases, classdict): msg = "'__columns__' must be of type tuple, found %s" % type(cls.__columns__) raise TypeError(msg) - if (len(bases) and 'DynamicTable' in globals() and issubclass(bases[-1], Container) - and bases[-1].__columns__ is not cls.__columns__): + if ( + len(bases) + and "DynamicTable" in globals() + and issubclass(bases[-1], Container) + and bases[-1].__columns__ is not cls.__columns__ + ): new_columns = list(cls.__columns__) new_columns[0:0] = bases[-1].__columns__ # prepend superclass columns to new_columns cls.__columns__ = tuple(new_columns) - @docval({'name': 'name', 'type': str, 'doc': 'the name of this table'}, # noqa: C901 - {'name': 'description', 'type': str, 'doc': 'a description of what is in this table'}, - {'name': 'id', 'type': ('array_data', 'data', ElementIdentifiers), 'doc': 'the identifiers for this table', - 'default': None}, - {'name': 'columns', 'type': (tuple, list), 'doc': 'the columns in this table', 'default': None}, - {'name': 'colnames', 'type': 'array_data', - 'doc': 'the ordered names of the columns in this table. columns must also be provided.', - 'default': None}, - allow_positional=AllowPositional.WARNING) + @docval( + { + "name": "name", + "type": str, + "doc": "the name of this table", + }, + { + "name": "description", + "type": str, + "doc": "a description of what is in this table", + }, + { + "name": "id", + "type": ("array_data", "data", ElementIdentifiers), + "doc": "the identifiers for this table", + "default": None, + }, + { + "name": "columns", + "type": (tuple, list), + "doc": "the columns in this table", + "default": None, + }, + { + "name": "colnames", + "type": "array_data", + "doc": "the ordered names of the columns in this table. 
columns must also be provided.", + "default": None, + }, + allow_positional=AllowPositional.WARNING, + ) def __init__(self, **kwargs): # noqa: C901 - id, columns, desc, colnames = popargs('id', 'columns', 'description', 'colnames', kwargs) + id, columns, desc, colnames = popargs("id", "columns", "description", "colnames", kwargs) super().__init__(**kwargs) self.description = desc @@ -305,9 +380,9 @@ def __init__(self, **kwargs): # noqa: C901 # Here, we figure out what to do for that if id is not None: if not isinstance(id, ElementIdentifiers): - id = ElementIdentifiers(name='id', data=id) + id = ElementIdentifiers(name="id", data=id) else: - id = ElementIdentifiers(name='id') + id = ElementIdentifiers(name="id") if columns is not None and len(columns) > 0: # If columns have been passed in, check them over and process accordingly @@ -406,8 +481,8 @@ def __init__(self, **kwargs): # noqa: C901 elif isinstance(col, EnumData): # EnumData is the indexing column, so it should go first if col.name not in indices: - indices[col.name] = [col] # EnumData is the indexing object - col_dict[col.name] = col.elements # EnumData.elements is the column with values + indices[col.name] = [col] # EnumData is the indexing object + col_dict[col.name] = col.elements # EnumData.elements is the column with values else: if col.name in indices: continue @@ -443,7 +518,7 @@ def __init__(self, **kwargs): # noqa: C901 col_dict[curr_col.target.name] = col if not hasattr(self, curr_col.target.name): self.__set_table_attr(curr_col.target) - else: # this is a regular VectorData or EnumData + else: # this is a regular VectorData or EnumData # if we added this column using its index, ignore this column if col.name in col_dict: continue @@ -459,14 +534,24 @@ def __init__(self, **kwargs): # noqa: C901 def __set_table_attr(self, col): if hasattr(self, col.name) and col.name not in self.__uninit_cols: - msg = ("An attribute '%s' already exists on %s '%s' so this column cannot be accessed as an attribute, " - "e.g., table.%s; it can only be accessed using other methods, e.g., table['%s']." - % (col.name, self.__class__.__name__, self.name, col.name, col.name)) + msg = ( + "An attribute '%s' already exists on %s '%s' so this column cannot be" + " accessed as an attribute, e.g., table.%s; it can only be accessed" + " using other methods, e.g., table['%s']." + % (col.name, self.__class__.__name__, self.name, col.name, col.name) + ) warn(msg) else: setattr(self, col.name, col) - __reserved_colspec_keys = ['name', 'description', 'index', 'table', 'required', 'class'] + __reserved_colspec_keys = [ + "name", + "description", + "index", + "table", + "required", + "class", + ] def _init_class_columns(self): """ @@ -474,36 +559,37 @@ def _init_class_columns(self): Optional columns are not tracked but not added. 
""" for col in self.__columns__: - if col['name'] not in self.__colids: # if column has not been added in __init__ - if col.get('required', False): - self.add_column(name=col['name'], - description=col['description'], - index=col.get('index', False), - table=col.get('table', False), - col_cls=col.get('class', VectorData), - # Pass through extra kwargs for add_column that subclasses may have added - **{k: col[k] for k in col.keys() - if k not in DynamicTable.__reserved_colspec_keys}) + if col["name"] not in self.__colids: # if column has not been added in __init__ + if col.get("required", False): + self.add_column( + name=col["name"], + description=col["description"], + index=col.get("index", False), + table=col.get("table", False), + col_cls=col.get("class", VectorData), + # Pass through extra kwargs for add_column that subclasses may have added + **{k: col[k] for k in col.keys() if k not in DynamicTable.__reserved_colspec_keys}, + ) else: # track the not yet initialized optional predefined columns - self.__uninit_cols[col['name']] = col + self.__uninit_cols[col["name"]] = col # set the table attributes for not yet init optional predefined columns - setattr(self, col['name'], None) - index = col.get('index', False) + setattr(self, col["name"], None) + index = col.get("index", False) if index is not False: if index is True: index = 1 if isinstance(index, int): assert index > 0, ValueError("integer index value must be greater than 0") - index_name = col['name'] + index_name = col["name"] for i in range(index): - index_name = index_name + '_index' + index_name = index_name + "_index" self.__uninit_cols[index_name] = col setattr(self, index_name, None) - if col.get('enum', False): - self.__uninit_cols[col['name'] + '_elements'] = col - setattr(self, col['name'] + '_elements', None) + if col.get("enum", False): + self.__uninit_cols[col["name"] + "_elements"] = col + setattr(self, col["name"] + "_elements", None) @staticmethod def __build_columns(columns, df=None): @@ -512,17 +598,20 @@ def __build_columns(columns, df=None): """ tmp = list() for d in columns: - name = d['name'] - desc = d.get('description', 'no description') - col_cls = d.get('class', VectorData) + name = d["name"] + desc = d.get("description", "no description") + col_cls = d.get("class", VectorData) data = None if df is not None: data = list(df[name].values) - index = d.get('index', False) + index = d.get("index", False) if index is not False: if isinstance(index, int) and index > 1: - raise ValueError('Creating nested index columns using this method is not yet supported. Use ' - 'add_column or define the columns using __columns__ instead.') + raise ValueError( + "Creating nested index columns using this method is not yet" + " supported. Use add_column or define the columns using" + " __columns__ instead." 
+ ) index_data = None if data is not None: index_data = [len(data[0])] @@ -538,7 +627,7 @@ def __build_columns(columns, df=None): vindex = VectorIndex(name="%s_index" % name, data=index_data, target=vdata) tmp.append(vindex) tmp.append(vdata) - elif d.get('enum', False): + elif d.get("enum", False): # EnumData is the indexing column, so it should go first if data is not None: elements, data = np.unique(data, return_inverse=True) @@ -551,7 +640,7 @@ def __build_columns(columns, df=None): else: if data is None: data = list() - if d.get('table', False): + if d.get("table", False): col_cls = DynamicTableRegion tmp.append(col_cls(name=name, description=desc, data=data)) return tmp @@ -560,16 +649,32 @@ def __len__(self): """Number of rows in the table""" return len(self.id) - @docval({'name': 'data', 'type': dict, 'doc': 'the data to put in this row', 'default': None}, - {'name': 'id', 'type': int, 'doc': 'the ID for the row', 'default': None}, - {'name': 'enforce_unique_id', 'type': bool, 'doc': 'enforce that the id in the table must be unique', - 'default': False}, - allow_extra=True) + @docval( + { + "name": "data", + "type": dict, + "doc": "the data to put in this row", + "default": None, + }, + { + "name": "id", + "type": int, + "doc": "the ID for the row", + "default": None, + }, + { + "name": "enforce_unique_id", + "type": bool, + "doc": "enforce that the id in the table must be unique", + "default": False, + }, + allow_extra=True, + ) def add_row(self, **kwargs): """ Add a row to the table. If *id* is not provided, it will auto-increment. """ - data, row_id, enforce_unique_id = popargs('data', 'id', 'enforce_unique_id', kwargs) + data, row_id, enforce_unique_id = popargs("data", "id", "enforce_unique_id", kwargs) data = data if data is not None else kwargs extra_columns = set(list(data.keys())) - set(list(self.__colids.keys())) @@ -578,29 +683,33 @@ def add_row(self, **kwargs): # check to see if any of the extra columns just need to be added if extra_columns: for col in self.__columns__: - if col['name'] in extra_columns: - if data[col['name']] is not None: - self.add_column(col['name'], col['description'], - index=col.get('index', False), - table=col.get('table', False), - enum=col.get('enum', False), - col_cls=col.get('class', VectorData), - # Pass through extra keyword arguments for add_column that - # subclasses may have added - **{k: col[k] for k in col.keys() - if k not in DynamicTable.__reserved_colspec_keys}) - extra_columns.remove(col['name']) + if col["name"] in extra_columns: + if data[col["name"]] is not None: + self.add_column( + col["name"], + col["description"], + index=col.get("index", False), + table=col.get("table", False), + enum=col.get("enum", False), + col_cls=col.get("class", VectorData), + # Pass through extra keyword arguments for add_column that + # subclasses may have added + **{k: col[k] for k in col.keys() if k not in DynamicTable.__reserved_colspec_keys}, + ) + extra_columns.remove(col["name"]) if extra_columns or missing_columns: raise ValueError( - '\n'.join([ - 'row data keys don\'t match available columns', - 'you supplied {} extra keys: {}'.format(len(extra_columns), extra_columns), - 'and were missing {} keys: {}'.format(len(missing_columns), missing_columns) - ]) + "\n".join( + [ + "row data keys don't match available columns", + "you supplied {} extra keys: {}".format(len(extra_columns), extra_columns), + "and were missing {} keys: {}".format(len(missing_columns), missing_columns), + ] + ) ) if row_id is None: - row_id = data.pop('id', None) + 
row_id = data.pop("id", None) if row_id is None: row_id = len(self) if enforce_unique_id: @@ -636,26 +745,60 @@ def __eq__(self, other): return False return self.to_dataframe().equals(other.to_dataframe()) - @docval({'name': 'name', 'type': str, 'doc': 'the name of this VectorData'}, # noqa: C901 - {'name': 'description', 'type': str, 'doc': 'a description for this column'}, - {'name': 'data', 'type': ('array_data', 'data'), - 'doc': 'a dataset where the first dimension is a concatenation of multiple vectors', 'default': list()}, - {'name': 'table', 'type': (bool, 'DynamicTable'), - 'doc': 'whether or not this is a table region or the table the region applies to', 'default': False}, - {'name': 'index', 'type': (bool, VectorIndex, 'array_data', int), - 'doc': ' * ``False`` (default): do not generate a VectorIndex\n\n' - ' * ``True``: generate one empty VectorIndex \n\n' - ' * ``VectorIndex``: Use the supplied VectorIndex \n\n' - ' * array-like of ints: Create a VectorIndex and use these values as the data \n\n' - ' * ``int``: Recursively create `n` VectorIndex objects for a multi-ragged array \n', - 'default': False}, - {'name': 'enum', 'type': (bool, 'array_data'), 'default': False, - 'doc': ('whether or not this column contains data from a fixed set of elements')}, - {'name': 'col_cls', 'type': type, 'default': VectorData, - 'doc': ('class to use to represent the column data. If table=True, this field is ignored and a ' - 'DynamicTableRegion object is used. If enum=True, this field is ignored and a EnumData ' - 'object is used.')}, - allow_extra=True) + @docval( + { + "name": "name", + "type": str, + "doc": "the name of this VectorData", + }, + { + "name": "description", + "type": str, + "doc": "a description for this column", + }, + { + "name": "data", + "type": ("array_data", "data"), + "doc": "a dataset where the first dimension is a concatenation of multiple vectors", + "default": list(), + }, + { + "name": "table", + "type": (bool, "DynamicTable"), + "doc": "whether or not this is a table region or the table the region applies to", + "default": False, + }, + { + "name": "index", + "type": (bool, VectorIndex, "array_data", int), + "doc": ( + " * ``False`` (default): do not generate a VectorIndex\n\n " + " * ``True``: generate one empty VectorIndex \n\n *" + " ``VectorIndex``: Use the supplied VectorIndex \n\n *" + " array-like of ints: Create a VectorIndex and use these values as the" + " data \n\n * ``int``: Recursively create `n` VectorIndex" + " objects for a multi-ragged array \n" + ), + "default": False, + }, + { + "name": "enum", + "type": (bool, "array_data"), + "default": False, + "doc": "whether or not this column contains data from a fixed set of elements", + }, + { + "name": "col_cls", + "type": type, + "default": VectorData, + "doc": ( + "class to use to represent the column data. If table=True, this field" + " is ignored and a DynamicTableRegion object is used. If enum=True," + " this field is ignored and a EnumData object is used." + ), + }, + allow_extra=True, + ) def add_column(self, **kwargs): # noqa: C901 """ Add a column to this table. 
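A minimal usage sketch, assuming the hdmf.common API and hypothetical table and column names; index=True creates the backing VectorIndex for a ragged column:

    from hdmf.common import DynamicTable

    table = DynamicTable(name="trials", description="an example table")
    table.add_column(name="scores", description="variable-length scores per trial", index=True)
    table.add_row(scores=[0.1, 0.2, 0.5])
    table.add_row(scores=[0.7])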
@@ -666,12 +809,18 @@ def add_column(self, **kwargs): # noqa: C901 :raises ValueError: if the column has already been added to the table """ - name, data = getargs('name', 'data', kwargs) - index, table, enum, col_cls = popargs('index', 'table', 'enum', 'col_cls', kwargs) + name, data = getargs("name", "data", kwargs) + index, table, enum, col_cls = popargs("index", "table", "enum", "col_cls", kwargs) if isinstance(index, VectorIndex): - warn("Passing a VectorIndex in for index may lead to unexpected behavior. This functionality will be " - "deprecated in a future version of HDMF.", FutureWarning) + warn( + ( + "Passing a VectorIndex in for index may lead to unexpected" + " behavior. This functionality will be deprecated in a future" + " version of HDMF." + ), + FutureWarning, + ) if name in self.__colids: # column has already been added msg = "column '%s' already exists in %s '%s'" % (name, self.__class__.__name__, self.name) @@ -681,55 +830,62 @@ def add_column(self, **kwargs): # noqa: C901 # check the given values against the predefined optional column spec. if they do not match, raise a warning # and ignore the given arguments. users should not be able to override these values table_bool = table or not isinstance(table, bool) - spec_table = self.__uninit_cols[name].get('table', False) + spec_table = self.__uninit_cols[name].get("table", False) if table_bool != spec_table: - msg = ("Column '%s' is predefined in %s with table=%s which does not match the entered " - "table argument. The predefined table spec will be ignored. " - "Please ensure the new column complies with the spec. " - "This will raise an error in a future version of HDMF." - % (name, self.__class__.__name__, spec_table)) + msg = ( + "Column '%s' is predefined in %s with table=%s which does not match" + " the entered table argument. The predefined table spec will be" + " ignored. Please ensure the new column complies with the spec." + " This will raise an error in a future version of HDMF." + % (name, self.__class__.__name__, spec_table) + ) warn(msg) index_bool = index or not isinstance(index, bool) - spec_index = self.__uninit_cols[name].get('index', False) + spec_index = self.__uninit_cols[name].get("index", False) if index_bool != spec_index: - msg = ("Column '%s' is predefined in %s with index=%s which does not match the entered " - "index argument. The predefined index spec will be ignored. " - "Please ensure the new column complies with the spec. " - "This will raise an error in a future version of HDMF." - % (name, self.__class__.__name__, spec_index)) + msg = ( + "Column '%s' is predefined in %s with index=%s which does not match" + " the entered index argument. The predefined index spec will be" + " ignored. Please ensure the new column complies with the spec." + " This will raise an error in a future version of HDMF." + % (name, self.__class__.__name__, spec_index) + ) warn(msg) - spec_col_cls = self.__uninit_cols[name].get('class', VectorData) + spec_col_cls = self.__uninit_cols[name].get("class", VectorData) if col_cls != spec_col_cls: - msg = ("Column '%s' is predefined in %s with class=%s which does not match the entered " - "col_cls argument. The predefined class spec will be ignored. " - "Please ensure the new column complies with the spec. " - "This will raise an error in a future version of HDMF." - % (name, self.__class__.__name__, spec_col_cls)) + msg = ( + "Column '%s' is predefined in %s with class=%s which does not match" + " the entered col_cls argument. 
The predefined class spec will be" + " ignored. Please ensure the new column complies with the spec." + " This will raise an error in a future version of HDMF." + % (name, self.__class__.__name__, spec_col_cls) + ) warn(msg) ckwargs = dict(kwargs) # Add table if it's been specified if table and enum: - raise ValueError("column '%s' cannot be both a table region " - "and come from an enumerable set of elements" % name) + raise ValueError( + "column '%s' cannot be both a table region and come from an enumerable set of elements" % name + ) if table is not False: col_cls = DynamicTableRegion if isinstance(table, DynamicTable): - ckwargs['table'] = table + ckwargs["table"] = table if enum is not False: col_cls = EnumData if isinstance(enum, (list, tuple, np.ndarray, VectorData)): - ckwargs['elements'] = enum + ckwargs["elements"] = enum # If the user provided a list of lists that needs to be indexed, then we now need to flatten the data # We can only create the index actual VectorIndex once we have the VectorData column so we compute # the index and flatten the data here and then create the VectorIndex later from create_vector_index # once we have created the column create_vector_index = None - if ckwargs.get('data', None) is not None: + if ckwargs.get("data", None) is not None: # Check that we are asked to create an index if (isinstance(index, bool) or isinstance(index, int)) and index > 0 and len(data) > 0: # Iteratively flatten the data we use for the column based on the depth of the index to generate. @@ -740,17 +896,23 @@ def add_column(self, **kwargs): # noqa: C901 try: create_vector_index.append(np.cumsum([len(c) for c in flatten_data]).tolist()) except TypeError as e: - raise ValueError("Cannot automatically construct VectorIndex for nested array. " - "Invalid data array element found.") from e + raise ValueError( + "Cannot automatically construct VectorIndex for nested" + " array. Invalid data array element found." + ) from e flatten_data = list(itertools.chain.from_iterable(flatten_data)) # if our data still is an array (e.g., a list or numpy array) then warn that the index parameter # may be incorrect. if len(flatten_data) > 0 and isinstance(flatten_data[0], (np.ndarray, list, tuple)): - raise ValueError("Cannot automatically construct VectorIndex for nested array. " - "Column data contains arrays as cell values. Please check the 'data' and 'index' " - "parameters. 'index=%s' may be too small for the given data." % str(index)) + raise ValueError( + "Cannot automatically construct VectorIndex for nested array." + " Column data contains arrays as cell values. Please check the" + " 'data' and 'index' parameters. 'index=%s' may be too small" + " for the given data." 
+ % str(index) + ) # overwrite the data to be used for the VectorData column with the flattened data - ckwargs['data'] = flatten_data + ckwargs["data"] = flatten_data # Create the VectorData column col = col_cls(**ckwargs) @@ -776,7 +938,11 @@ def add_column(self, **kwargs): # noqa: C901 col_index = VectorIndex(name=name + "_index", data=list(), target=col) # create single-level VectorIndex from the data based on the create_vector_index we computed earlier else: - col_index = VectorIndex(name=name + "_index", data=create_vector_index[0], target=col) + col_index = VectorIndex( + name=name + "_index", + data=create_vector_index[0], + target=col, + ) # add the column with the index self.__add_column_index_helper(col_index) elif isinstance(index, int): @@ -796,7 +962,11 @@ def add_column(self, **kwargs): # noqa: C901 index_name = name for i in range(index): index_name = index_name + "_index" - col_index = VectorIndex(name=index_name, data=create_vector_index[-(i+1)], target=col) + col_index = VectorIndex( + name=index_name, + data=create_vector_index[-(i + 1)], + target=col, + ) self.__add_column_index_helper(col_index) if i < index - 1: columns.insert(0, col_index) @@ -811,8 +981,8 @@ def add_column(self, **kwargs): # noqa: C901 if len(col) != len(self.id): raise ValueError("column must have the same number of rows as 'id'") self.__colids[name] = len(self.__df_cols) - self.fields['colnames'] = tuple(list(self.colnames) + [name]) - self.fields['columns'] = tuple(list(self.columns) + columns) + self.fields["colnames"] = tuple(list(self.colnames) + [name]) + self.fields["columns"] = tuple(list(self.columns) + columns) self.__df_cols.append(col) def __add_column_index_helper(self, col_index): @@ -824,9 +994,23 @@ def __add_column_index_helper(self, col_index): if col_index in self.__uninit_cols: self.__uninit_cols.pop(col_index) - @docval({'name': 'name', 'type': str, 'doc': 'the name of the DynamicTableRegion object'}, - {'name': 'region', 'type': (slice, list, tuple), 'doc': 'the indices of the table'}, - {'name': 'description', 'type': str, 'doc': 'a brief description of what the region is'}) + @docval( + { + "name": "name", + "type": str, + "doc": "the name of the DynamicTableRegion object", + }, + { + "name": "region", + "type": (slice, list, tuple), + "doc": "the indices of the table", + }, + { + "name": "description", + "type": str, + "doc": "a brief description of what the region is", + }, + ) def create_region(self, **kwargs): """ Create a DynamicTableRegion selecting a region (i.e., rows) in this DynamicTable. 
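A minimal sketch of creating a region, assuming a hypothetical table with three rows:

    from hdmf.common import DynamicTable, VectorData

    table = DynamicTable(
        name="trials",
        description="an example table",
        columns=[VectorData(name="value", description="a value per row", data=[1, 2, 3])],
    )
    subset = table.create_region(name="first_two", region=[0, 1], description="the first two rows")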
@@ -834,20 +1018,20 @@ def create_region(self, **kwargs): :raises: IndexError if the provided region contains invalid indices """ - region = getargs('region', kwargs) + region = getargs("region", kwargs) if isinstance(region, slice): if (region.start is not None and region.start < 0) or (region.stop is not None and region.stop > len(self)): - msg = 'region slice %s is out of range for this DynamicTable of length %d' % (str(region), len(self)) + msg = "region slice %s is out of range for this DynamicTable of length %d" % (str(region), len(self)) raise IndexError(msg) region = list(range(*region.indices(len(self)))) else: for idx in region: if idx < 0 or idx >= len(self): - raise IndexError('The index ' + str(idx) + - ' is out of range for this DynamicTable of length ' - + str(len(self))) - desc = getargs('description', kwargs) - name = getargs('name', kwargs) + raise IndexError( + "The index " + str(idx) + " is out of range for this DynamicTable of length " + str(len(self)) + ) + desc = getargs("description", kwargs) + name = getargs("name", kwargs) return DynamicTableRegion(name=name, data=region, description=desc, table=self) def __getitem__(self, key): @@ -888,7 +1072,7 @@ def get(self, key, default=None, df=True, index=True, **kwargs): ret = None if not df and not index: # returning nested lists of lists for DTRs and ragged DTRs is complicated and not yet supported - raise ValueError('DynamicTable.get() with df=False and index=False is not yet supported.') + raise ValueError("DynamicTable.get() with df=False and index=False is not yet supported.") if isinstance(key, tuple): # index by row and column --> return specific cell arg1 = key[0] @@ -898,7 +1082,7 @@ def get(self, key, default=None, df=True, index=True, **kwargs): ret = self.__df_cols[arg2][arg1] elif isinstance(key, str): # index by one string --> return column - if key == 'id': + if key == "id": return self.id elif key in self.__colids: ret = self.__df_cols[self.__colids[key]] @@ -937,13 +1121,15 @@ def __get_selection_as_dict(self, arg, df, index, exclude=None, **kwargs): ret = OrderedDict() try: # index with a python slice or single int to select one or multiple rows - ret['id'] = self.id[arg] + ret["id"] = self.id[arg] for name in self.colnames: if name in exclude: continue col = self.__df_cols[self.__colids[name]] - if index and (isinstance(col, DynamicTableRegion) or - (isinstance(col, VectorIndex) and isinstance(col.target, DynamicTableRegion))): + if index and ( + isinstance(col, DynamicTableRegion) + or (isinstance(col, VectorIndex) and isinstance(col.target, DynamicTableRegion)) + ): # return indices (in list, array, etc.) for DTR and ragged DTR ret[name] = col.get(arg, df=False, index=True, **kwargs) else: @@ -956,20 +1142,31 @@ def __get_selection_as_dict(self, arg, df, index, exclude=None, **kwargs): # in h5py 3+, this became an IndexError x = re.match(r"^Index \((.*)\) out of range \(.*\)$", str(ve)) if x: - msg = ("Row index %s out of range for %s '%s' (length %d)." - % (x.groups()[0], self.__class__.__name__, self.name, len(self))) + msg = "Row index %s out of range for %s '%s' (length %d)." % ( + x.groups()[0], + self.__class__.__name__, + self.name, + len(self), + ) raise IndexError(msg) from ve else: # pragma: no cover raise ve except IndexError as ie: x = re.match(r"^Index \((.*)\) out of range for \(.*\)$", str(ie)) if x: - msg = ("Row index %s out of range for %s '%s' (length %d)." 
- % (x.groups()[0], self.__class__.__name__, self.name, len(self))) + msg = "Row index %s out of range for %s '%s' (length %d)." % ( + x.groups()[0], + self.__class__.__name__, + self.name, + len(self), + ) raise IndexError(msg) - elif str(ie) == 'list index out of range': - msg = ("Row index out of range for %s '%s' (length %d)." - % (self.__class__.__name__, self.name, len(self))) + elif str(ie) == "list index out of range": + msg = "Row index out of range for %s '%s' (length %d)." % ( + self.__class__.__name__, + self.name, + len(self), + ) raise IndexError(msg) from ie else: # pragma: no cover raise ie @@ -982,7 +1179,7 @@ def __get_selection_as_df_single_row(self, coldata): :param coldata: dict mapping column names to values (list/arrays or dataframes) :type coldata: dict """ - id_index_orig = coldata.pop('id') + id_index_orig = coldata.pop("id") id_index = [id_index_orig] df_input = OrderedDict() for k in coldata: # for each column @@ -1005,7 +1202,7 @@ def __get_selection_as_df(self, coldata): :param coldata: dict mapping column names to values (list/arrays or dataframes) :type coldata: dict """ - id_index = coldata.pop('id') + id_index = coldata.pop("id") df_input = OrderedDict() for k in coldata: # for each column if isinstance(coldata[k], np.ndarray) and coldata[k].ndim > 1: @@ -1052,11 +1249,19 @@ def has_foreign_columns(self): return True return False - @docval({'name': 'other_tables', 'type': (list, tuple, set), - 'doc': "List of additional tables to consider in the search. Usually this " - "parameter is used for internal purposes, e.g., when we need to " - "consider AlignedDynamicTable", 'default': None}, - allow_extra=False) + @docval( + { + "name": "other_tables", + "type": (list, tuple, set), + "doc": ( + "List of additional tables to consider in the search. Usually this " + "parameter is used for internal purposes, e.g., when we need to " + "consider AlignedDynamicTable" + ), + "default": None, + }, + allow_extra=False, + ) def get_linked_tables(self, **kwargs): """ Get a list of the full list of all tables that are being linked to directly or indirectly @@ -1068,12 +1273,18 @@ def get_linked_tables(self, **kwargs): * 'source_column' : The relevant DynamicTableRegion column in the 'source_table' * 'target_table' : The target DynamicTable; same as source_column.table. 
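A minimal sketch of the returned links, assuming two hypothetical tables joined by a DynamicTableRegion column:

    from hdmf.common import DynamicTable, VectorData

    trials = DynamicTable(
        name="trials",
        description="the referenced table",
        columns=[VectorData(name="start_time", description="start time per trial", data=[0.0, 1.5])],
    )
    events = DynamicTable(name="events", description="events pointing into trials")
    events.add_column(name="trial", description="the trial each event belongs to", table=trials)
    events.add_row(trial=0)

    for link in events.get_linked_tables():
        # prints: events trial trials
        print(link.source_table.name, link.source_column.name, link.target_table.name)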
""" - link_type = NamedTuple('DynamicTableLink', - [('source_table', DynamicTable), - ('source_column', Union[DynamicTableRegion, VectorIndex]), - ('target_table', DynamicTable)]) - curr_tables = [self, ] # Set of tables - other_tables = getargs('other_tables', kwargs) + link_type = NamedTuple( + "DynamicTableLink", + [ + ("source_table", DynamicTable), + ("source_column", Union[DynamicTableRegion, VectorIndex]), + ("target_table", DynamicTable), + ], + ) + curr_tables = [ + self, + ] # Set of tables + other_tables = getargs("other_tables", kwargs) if other_tables is not None: curr_tables += other_tables curr_index = 0 @@ -1081,9 +1292,13 @@ def get_linked_tables(self, **kwargs): while curr_index < len(curr_tables): for col_index, col in enumerate(curr_tables[curr_index].columns): if isinstance(col, DynamicTableRegion): - foreign_cols.append(link_type(source_table=curr_tables[curr_index], - source_column=col, - target_table=col.table)) + foreign_cols.append( + link_type( + source_table=curr_tables[curr_index], + source_column=col, + target_table=col.table, + ) + ) curr_table_visited = False for t in curr_tables: if t is col.table: @@ -1093,13 +1308,23 @@ def get_linked_tables(self, **kwargs): curr_index += 1 return foreign_cols - @docval({'name': 'exclude', 'type': set, 'doc': 'Set of column names to exclude from the dataframe', - 'default': None}, - {'name': 'index', 'type': bool, - 'doc': ('Whether to return indices for a DynamicTableRegion column. If False, nested dataframes will be ' - 'returned.'), - 'default': False} - ) + @docval( + { + "name": "exclude", + "type": set, + "doc": "Set of column names to exclude from the dataframe", + "default": None, + }, + { + "name": "index", + "type": bool, + "doc": ( + "Whether to return indices for a DynamicTableRegion column. If False," + " nested dataframes will be returned." + ), + "default": False, + }, + ) def to_dataframe(self, **kwargs): """ Produce a pandas DataFrame containing this table's data. @@ -1115,30 +1340,38 @@ def to_dataframe(self, **kwargs): @classmethod @docval( - {'name': 'df', 'type': pd.DataFrame, 'doc': 'source DataFrame'}, - {'name': 'name', 'type': str, 'doc': 'the name of this table'}, { - 'name': 'index_column', - 'type': str, - 'doc': 'if provided, this column will become the table\'s index', - 'default': None + "name": "df", + "type": pd.DataFrame, + "doc": "source DataFrame", + }, + { + "name": "name", + "type": str, + "doc": "the name of this table", + }, + { + "name": "index_column", + "type": str, + "doc": "if provided, this column will become the table's index", + "default": None, }, { - 'name': 'table_description', - 'type': str, - 'doc': 'a description of what is in the resulting table', - 'default': '' + "name": "table_description", + "type": str, + "doc": "a description of what is in the resulting table", + "default": "", }, { - 'name': 'columns', - 'type': (list, tuple), - 'doc': 'a list/tuple of dictionaries specifying columns in the table', - 'default': None + "name": "columns", + "type": (list, tuple), + "doc": "a list/tuple of dictionaries specifying columns in the table", + "default": None, }, - allow_extra=True + allow_extra=True, ) def from_dataframe(cls, **kwargs): - ''' + """ Construct an instance of DynamicTable (or a subclass) from a pandas DataFrame. The columns of the resulting table are defined by the columns of the @@ -1148,26 +1381,26 @@ def from_dataframe(cls, **kwargs): dictionaries containing the name and description of the column- to help others understand the contents of your table. 
See :py:class:`~hdmf.common.table.DynamicTable` for more details on *columns*. - ''' + """ - columns = kwargs.pop('columns') - df = kwargs.pop('df') - name = kwargs.pop('name') - index_column = kwargs.pop('index_column') - table_description = kwargs.pop('table_description') - column_descriptions = kwargs.pop('column_descriptions', dict()) + columns = kwargs.pop("columns") + df = kwargs.pop("df") + name = kwargs.pop("name") + index_column = kwargs.pop("index_column") + table_description = kwargs.pop("table_description") + column_descriptions = kwargs.pop("column_descriptions", dict()) supplied_columns = dict() if columns: - supplied_columns = {x['name']: x for x in columns} + supplied_columns = {x["name"]: x for x in columns} - class_cols = {x['name']: x for x in cls.__columns__} - required_cols = set(x['name'] for x in cls.__columns__ if 'required' in x and x['required']) + class_cols = {x["name"]: x for x in cls.__columns__} + required_cols = set(x["name"] for x in cls.__columns__ if "required" in x and x["required"]) df_cols = df.columns if required_cols - set(df_cols): - raise ValueError('missing required cols: ' + str(required_cols - set(df_cols))) + raise ValueError("missing required cols: " + str(required_cols - set(df_cols))) if set(supplied_columns.keys()) - set(df_cols): - raise ValueError('cols specified but not provided: ' + str(set(supplied_columns.keys()) - set(df_cols))) + raise ValueError("cols specified but not provided: " + str(set(supplied_columns.keys()) - set(df_cols))) columns = [] for col_name in df_cols: if col_name in class_cols: @@ -1175,9 +1408,13 @@ def from_dataframe(cls, **kwargs): elif col_name in supplied_columns: columns.append(supplied_columns[col_name]) else: - columns.append({'name': col_name, - 'description': column_descriptions.get(col_name, 'no description')}) - if hasattr(df[col_name].iloc[0], '__len__') and not isinstance(df[col_name].iloc[0], str): + columns.append( + { + "name": col_name, + "description": column_descriptions.get(col_name, "no description"), + } + ) + if hasattr(df[col_name].iloc[0], "__len__") and not isinstance(df[col_name].iloc[0], str): lengths = [len(x) for x in df[col_name]] if not lengths[1:] == lengths[:-1]: columns[-1].update(index=True) @@ -1185,24 +1422,35 @@ def from_dataframe(cls, **kwargs): if index_column is not None: ids = ElementIdentifiers(name=index_column, data=df[index_column].values.tolist()) else: - index_name = df.index.name if df.index.name is not None else 'id' + index_name = df.index.name if df.index.name is not None else "id" ids = ElementIdentifiers(name=index_name, data=df.index.values.tolist()) columns = cls.__build_columns(columns, df=df) - return cls(name=name, id=ids, columns=columns, description=table_description, **kwargs) + return cls( + name=name, + id=ids, + columns=columns, + description=table_description, + **kwargs, + ) def copy(self): """ Return a copy of this DynamicTable. This is useful for linking. """ - kwargs = dict(name=self.name, id=self.id, columns=self.columns, description=self.description, - colnames=self.colnames) + kwargs = dict( + name=self.name, + id=self.id, + columns=self.columns, + description=self.description, + colnames=self.colnames, + ) return self.__class__(**kwargs) -@register_class('DynamicTableRegion') +@register_class("DynamicTableRegion") class DynamicTableRegion(VectorData): """ DynamicTableRegion provides a link from one table to an index or region of another. 
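A minimal construction sketch, assuming a hypothetical target table named trials:

    from hdmf.common import DynamicTable, DynamicTableRegion, VectorData

    trials = DynamicTable(
        name="trials",
        description="the table being referenced",
        columns=[VectorData(name="value", description="a value per trial", data=[10, 20, 30])],
    )
    region = DynamicTableRegion(
        name="trial_ref",
        data=[0, 2, 2],
        description="rows of the trials table, in order of use",
        table=trials,
    )
    region[0]  # resolves to the first referenced row of trials as a pandas DataFrame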
The `table` @@ -1214,26 +1462,41 @@ class DynamicTableRegion(VectorData): `DynamicTable` can reference many rows of another `DynamicTable`. """ - __fields__ = ( - 'table', - ) + __fields__ = ("table",) - @docval({'name': 'name', 'type': str, 'doc': 'the name of this VectorData'}, - {'name': 'data', 'type': ('array_data', 'data'), - 'doc': 'a dataset where the first dimension is a concatenation of multiple vectors'}, - {'name': 'description', 'type': str, 'doc': 'a description of what this region represents'}, - {'name': 'table', 'type': DynamicTable, - 'doc': 'the DynamicTable this region applies to', 'default': None}, - allow_positional=AllowPositional.WARNING) + @docval( + { + "name": "name", + "type": str, + "doc": "the name of this VectorData", + }, + { + "name": "data", + "type": ("array_data", "data"), + "doc": "a dataset where the first dimension is a concatenation of multiple vectors", + }, + { + "name": "description", + "type": str, + "doc": "a description of what this region represents", + }, + { + "name": "table", + "type": DynamicTable, + "doc": "the DynamicTable this region applies to", + "default": None, + }, + allow_positional=AllowPositional.WARNING, + ) def __init__(self, **kwargs): - t = popargs('table', kwargs) + t = popargs("table", kwargs) super().__init__(**kwargs) self.table = t @property def table(self): """The DynamicTable this DynamicTableRegion is pointing to""" - return self.fields.get('table') + return self.fields.get("table") @table.setter def table(self, val): @@ -1247,13 +1510,13 @@ def table(self, val): """ if val is None: return - if 'table' in self.fields: + if "table" in self.fields: msg = "can't set attribute 'table' -- already set" raise AttributeError(msg) dat = self.data if isinstance(dat, DataIO): dat = dat.data - self.fields['table'] = val + self.fields["table"] = val def __getitem__(self, arg): return self.get(arg) @@ -1280,7 +1543,7 @@ def get(self, arg, index=False, df=True, **kwargs): """ if not df and not index: # returning nested lists of lists for DTRs and ragged DTRs is complicated and not yet supported - raise ValueError('DynamicTableRegion.get() with df=False and index=False is not yet supported.') + raise ValueError("DynamicTableRegion.get() with df=False and index=False is not yet supported.") # treat the list of indices as data that can be indexed. then pass the # result to the table to get the data if isinstance(arg, tuple): @@ -1291,7 +1554,7 @@ def get(self, arg, index=False, df=True, **kwargs): return self.table[arg] elif np.issubdtype(type(arg), np.integer): if arg >= len(self.data): - raise IndexError('index {} out of bounds for data of length {}'.format(arg, len(self.data))) + raise IndexError("index {} out of bounds for data of length {}".format(arg, len(self.data))) ret = self.data[arg] if not index: ret = self.table.get(ret, df=df, index=index, **kwargs) @@ -1345,7 +1608,7 @@ def _index_lol(self, result, index, lut): elif isinstance(col, np.ndarray): ret.append(np.array([col[lut[i]] for i in index], dtype=col.dtype)) else: - raise ValueError('unrecognized column type: %s. Expected list or np.ndarray' % type(col)) + raise ValueError("unrecognized column type: %s. 
Expected list or np.ndarray" % type(col)) return ret def to_dataframe(self, **kwargs): @@ -1372,47 +1635,70 @@ def __repr__(self): """ cls = self.__class__ template = "%s %s.%s at 0x%d\n" % (self.name, cls.__module__, cls.__name__, id(self)) - template += " Target table: %s %s.%s at 0x%d\n" % (self.table.name, - self.table.__class__.__module__, - self.table.__class__.__name__, - id(self.table)) + template += " Target table: %s %s.%s at 0x%d\n" % ( + self.table.name, + self.table.__class__.__module__, + self.table.__class__.__name__, + id(self.table), + ) return template def _uint_precision(elements): - """ Calculate the uint precision needed to encode a set of elements """ + """Calculate the uint precision needed to encode a set of elements""" n_elements = elements - if hasattr(elements, '__len__'): + if hasattr(elements, "__len__"): n_elements = len(elements) - return np.dtype('uint%d' % (8 * max(1, int((2 ** np.ceil((np.ceil(np.log2(n_elements)) - 8) / 8)))))).type + return np.dtype("uint%d" % (8 * max(1, int((2 ** np.ceil((np.ceil(np.log2(n_elements)) - 8) / 8)))))).type def _map_elements(uint, elements): - """ Map CV terms to their uint index """ + """Map CV terms to their uint index""" return {t[1]: uint(t[0]) for t in enumerate(elements)} -@register_class('EnumData', EXP_NAMESPACE) +@register_class("EnumData", EXP_NAMESPACE) class EnumData(VectorData): """ A n-dimensional dataset that can contain elements from fixed set of elements. """ - __fields__ = ('elements', ) + __fields__ = ("elements",) - @docval({'name': 'name', 'type': str, 'doc': 'the name of this column'}, - {'name': 'description', 'type': str, 'doc': 'a description for this column'}, - {'name': 'data', 'type': ('array_data', 'data'), - 'doc': 'integers that index into elements for the value of each row', 'default': list()}, - {'name': 'elements', 'type': ('array_data', 'data', VectorData), 'default': list(), - 'doc': 'lookup values for each integer in ``data``'}, - allow_positional=AllowPositional.WARNING) + @docval( + { + "name": "name", + "type": str, + "doc": "the name of this column", + }, + { + "name": "description", + "type": str, + "doc": "a description for this column", + }, + { + "name": "data", + "type": ("array_data", "data"), + "doc": "integers that index into elements for the value of each row", + "default": list(), + }, + { + "name": "elements", + "type": ("array_data", "data", VectorData), + "default": list(), + "doc": "lookup values for each integer in ``data``", + }, + allow_positional=AllowPositional.WARNING, + ) def __init__(self, **kwargs): - elements = popargs('elements', kwargs) + elements = popargs("elements", kwargs) super().__init__(**kwargs) if not isinstance(elements, VectorData): - elements = VectorData(name='%s_elements' % self.name, data=elements, - description='fixed set of elements referenced by %s' % self.name) + elements = VectorData( + name="%s_elements" % self.name, + data=elements, + description="fixed set of elements referenced by %s" % self.name, + ) self.elements = elements if len(self.elements) > 0: self.__uint = _uint_precision(self.elements.data) @@ -1459,7 +1745,7 @@ def _get_helper(self, idx, index=False, join=False, **kwargs): idx = np.asarray(idx) ret = np.asarray(self.elements.get(idx.ravel(), **kwargs)).reshape(idx.shape) if join: - ret = ''.join(ret.ravel()) + ret = "".join(ret.ravel()) else: ret = self.elements.get(idx, **kwargs) return ret @@ -1479,16 +1765,17 @@ def get(self, arg, index=False, join=False, **kwargs): idx = self.data[arg] return self._get_helper(idx, 
index=index, join=join, **kwargs) - @docval({'name': 'val', 'type': None, 'doc': 'the value to add to this column'}, - {'name': 'index', 'type': bool, 'doc': 'whether or not the value being added is an index', - 'default': False}) + @docval( + {"name": "val", "type": None, "doc": "the value to add to this column"}, + {"name": "index", "type": bool, "doc": "whether or not the value being added is an index", "default": False}, + ) def add_row(self, **kwargs): """Append a data value to this EnumData column If an element is provided for *val* (i.e. *index* is False), the correct index value will be determined. Otherwise, *val* will be added as provided. """ - val, index = getargs('val', 'index', kwargs) + val, index = getargs("val", "index", kwargs) if not index: val = self.__add_term(val) super().append(val) diff --git a/src/hdmf/container.py b/src/hdmf/container.py index 762ebeae1..37713900d 100644 --- a/src/hdmf/container.py +++ b/src/hdmf/container.py @@ -10,7 +10,15 @@ import pandas as pd from .data_utils import DataIO, append_data, extend_data -from .utils import docval, get_docval, getargs, ExtenderMeta, get_data_shape, popargs, LabelledDict +from .utils import ( + ExtenderMeta, + LabelledDict, + docval, + get_data_shape, + get_docval, + getargs, + popargs, +) def _set_exp(cls): @@ -23,8 +31,10 @@ def _exp_warn_msg(cls): pfx = cls if isinstance(cls, type): pfx = cls.__name__ - msg = ('%s is experimental -- it may be removed in the future and ' - 'is not guaranteed to maintain backward compatibility') % pfx + msg = ( + "%s is experimental -- it may be removed in the future and is not guaranteed to maintain backward compatibility" + % pfx + ) return msg @@ -33,13 +43,18 @@ class ExternalResourcesManager: This class manages whether to set/attach an instance of ExternalResources to the subclass. """ - @docval({'name': 'external_resources', 'type': 'ExternalResources', - 'doc': 'The external resources to be used for the container.'},) + @docval( + { + "name": "external_resources", + "type": "ExternalResources", + "doc": "The external resources to be used for the container.", + }, + ) def link_resources(self, **kwargs): """ Method to attach an instance of ExternalResources in order to auto-add terms/references to data. """ - self._external_resources = kwargs['external_resources'] + self._external_resources = kwargs["external_resources"] def get_linked_resources(self): return self._external_resources if hasattr(self, "_external_resources") else None @@ -52,9 +67,9 @@ class AbstractContainer(metaclass=ExtenderMeta): _experimental = False - _fieldsname = '__fields__' + _fieldsname = "__fields__" - _data_type_attr = 'data_type' + _data_type_attr = "data_type" # Subclasses use this class attribute to add properties to autogenerate # Autogenerated properties will store values in self.__field_values @@ -64,7 +79,7 @@ class AbstractContainer(metaclass=ExtenderMeta): # It holds all the values in __fields__ for this class and its parent classes. 
__fieldsconf = tuple() - _pconf_allowed_keys = {'name', 'doc', 'settable'} + _pconf_allowed_keys = {"name", "doc", "settable"} # Override the _setter factor function, so directives that apply to # Container do not get used on Data @@ -73,9 +88,9 @@ def _setter(cls, field): """ Make a setter function for creating a :py:func:`property` """ - name = field['name'] + name = field["name"] - if not field.get('settable', True): + if not field.get("settable", True): return None def setter(self, val): @@ -93,13 +108,13 @@ def _getter(cls, field): """ Make a getter function for creating a :py:func:`property` """ - doc = field.get('doc') - name = field['name'] + doc = field.get("doc") + name = field["name"] def getter(self): return self.fields.get(name) - setattr(getter, '__doc__', doc) + setattr(getter, "__doc__", doc) return getter @staticmethod @@ -110,18 +125,22 @@ def _check_field_spec(field): """ tmp = field if isinstance(tmp, dict): - if 'name' not in tmp: + if "name" not in tmp: raise ValueError("must specify 'name' if using dict in __fields__") else: - tmp = {'name': tmp} + tmp = {"name": tmp} return tmp @classmethod def _check_field_spec_keys(cls, field_conf): for k in field_conf: if k not in cls._pconf_allowed_keys: - msg = ("Unrecognized key '%s' in %s config '%s' on %s" - % (k, cls._fieldsname, field_conf['name'], cls.__name__)) + msg = "Unrecognized key '%s' in %s config '%s' on %s" % ( + k, + cls._fieldsname, + field_conf["name"], + cls.__name__, + ) raise ValueError(msg) @classmethod @@ -138,10 +157,10 @@ def get_fields_conf(cls): @ExtenderMeta.pre_init def __gather_fields(cls, name, bases, classdict): - ''' + """ This classmethod will be called during class declaration in the metaclass to automatically create setters and getters for fields that need to be exported - ''' + """ fields = cls._get_fields() if not isinstance(fields, tuple): msg = "'%s' must be of type tuple" % cls._fieldsname @@ -152,7 +171,7 @@ def __gather_fields(cls, name, bases, classdict): for f in fields: pconf = cls._check_field_spec(f) cls._check_field_spec_keys(pconf) - fields_dict[pconf['name']] = pconf + fields_dict[pconf["name"]] = pconf all_fields_conf = list(fields_dict.values()) # check whether this class overrides __fields__ @@ -175,20 +194,24 @@ def __gather_fields(cls, name, bases, classdict): base_fields_conf = base_cls.get_fields_conf() # tuple of fields configurations from base class base_fields_conf_to_add = list() for pconf in base_fields_conf: - if pconf['name'] not in fields_to_remove_from_base: + if pconf["name"] not in fields_to_remove_from_base: base_fields_conf_to_add.append(pconf) all_fields_conf[0:0] = base_fields_conf_to_add # create getter and setter if attribute does not already exist # if 'doc' not specified in __fields__, use doc from docval of __init__ - docs = {dv['name']: dv['doc'] for dv in get_docval(cls.__init__)} + docs = {dv["name"]: dv["doc"] for dv in get_docval(cls.__init__)} for field_conf in all_fields_conf: - pname = field_conf['name'] - field_conf.setdefault('doc', docs.get(pname)) + pname = field_conf["name"] + field_conf.setdefault("doc", docs.get(pname)) if not hasattr(cls, pname): - setattr(cls, pname, property(cls._getter(field_conf), cls._setter(field_conf))) + setattr( + cls, + pname, + property(cls._getter(field_conf), cls._setter(field_conf)), + ) - cls._set_fields(tuple(field_conf['name'] for field_conf in all_fields_conf)) + cls._set_fields(tuple(field_conf["name"] for field_conf in all_fields_conf)) cls.__fieldsconf = tuple(all_fields_conf) def __new__(cls, 
*args, **kwargs): @@ -202,38 +225,38 @@ def __new__(cls, *args, **kwargs): inst = super().__new__(cls) if cls._experimental: warn(_exp_warn_msg(cls)) - inst.__container_source = kwargs.pop('container_source', None) + inst.__container_source = kwargs.pop("container_source", None) inst.__parent = None inst.__children = list() inst.__modified = True - inst.__object_id = kwargs.pop('object_id', str(uuid4())) + inst.__object_id = kwargs.pop("object_id", str(uuid4())) # this variable is being passed in from ObjectMapper.__new_container__ and is # reset to False in that method after the object has been initialized by __init__ - inst._in_construct_mode = kwargs.pop('in_construct_mode', False) - inst.parent = kwargs.pop('parent', None) + inst._in_construct_mode = kwargs.pop("in_construct_mode", False) + inst.parent = kwargs.pop("parent", None) return inst - @docval({'name': 'name', 'type': str, 'doc': 'the name of this container'}) + @docval({"name": "name", "type": str, "doc": "the name of this container"}) def __init__(self, **kwargs): - name = getargs('name', kwargs) - if '/' in name: + name = getargs("name", kwargs) + if "/" in name: raise ValueError("name '" + name + "' cannot contain '/'") self.__name = name self.__field_values = dict() @property def name(self): - ''' + """ The name of this Container - ''' + """ return self.__name - @docval({'name': 'data_type', 'type': str, 'doc': 'the data_type to search for', 'default': None}) + @docval({"name": "data_type", "type": str, "doc": "the data_type to search for", "default": None}) def get_ancestor(self, **kwargs): """ Traverse parent hierarchy and return first instance of the specified data_type """ - data_type = getargs('data_type', kwargs) + data_type = getargs("data_type", kwargs) if data_type is None: return self.parent p = self.parent @@ -245,7 +268,7 @@ def get_ancestor(self, **kwargs): @property def fields(self): - ''' + """ Subclasses use this class attribute to add properties to autogenerate. `fields` allows for lists and for dicts with the keys {'name', 'child', 'required_name', 'doc', 'settable'}. 1. name: The name of the field property @@ -253,7 +276,7 @@ def fields(self): 3. required_name: The name the field property must have such that `name` matches `required_name`. 4. doc: Documentation of the field property 5. settable: If true, a setter function is created so that the field can be changed after creation. 
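For example, a minimal sketch of a hypothetical Container subclass using these keys (a 'doc' entry in __fields__ takes precedence over the docstring pulled from __init__'s docval; 'child': True would additionally set the container as the parent of any assigned value):

    from hdmf.container import Container
    from hdmf.utils import docval, popargs

    class Device(Container):
        # each entry becomes an auto-generated property backed by self.fields
        __fields__ = (
            {"name": "manufacturer", "doc": "who made the device"},
            "serial_number",
        )

        @docval(
            {"name": "name", "type": str, "doc": "the name of this device"},
            {"name": "manufacturer", "type": str, "doc": "who made the device", "default": None},
            {"name": "serial_number", "type": str, "doc": "the device serial number", "default": None},
        )
        def __init__(self, **kwargs):
            manufacturer, serial_number = popargs("manufacturer", "serial_number", kwargs)
            super().__init__(**kwargs)
            self.manufacturer = manufacturer
            self.serial_number = serial_number

    dev = Device(name="probe-1", manufacturer="Acme", serial_number="A-001")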
- ''' + """ return self.__field_values @property @@ -262,11 +285,17 @@ def object_id(self): self.__object_id = str(uuid4()) return self.__object_id - @docval({'name': 'recurse', 'type': bool, - 'doc': "whether or not to change the object ID of this container's children", 'default': True}) + @docval( + { + "name": "recurse", + "type": bool, + "doc": "whether or not to change the object ID of this container's children", + "default": True, + } + ) def generate_new_id(self, **kwargs): """Changes the object ID of this Container and all of its children to a new UUID string.""" - recurse = getargs('recurse', kwargs) + recurse = getargs("recurse", kwargs) self.__object_id = str(uuid4()) self.set_modified() if recurse: @@ -277,10 +306,11 @@ def generate_new_id(self, **kwargs): def modified(self): return self.__modified - @docval({'name': 'modified', 'type': bool, - 'doc': 'whether or not this Container has been modified', 'default': True}) + @docval( + {"name": "modified", "type": bool, "doc": "whether or not this Container has been modified", "default": True} + ) def set_modified(self, **kwargs): - modified = getargs('modified', kwargs) + modified = getargs("modified", kwargs) self.__modified = modified if modified and isinstance(self.parent, Container): self.parent.set_modified() @@ -289,11 +319,10 @@ def set_modified(self, **kwargs): def children(self): return tuple(self.__children) - @docval({'name': 'child', 'type': 'Container', - 'doc': 'the child Container for this Container', 'default': None}) + @docval({"name": "child", "type": "Container", "doc": "the child Container for this Container", "default": None}) def add_child(self, **kwargs): - warn(DeprecationWarning('add_child is deprecated. Set the parent attribute instead.')) - child = getargs('child', kwargs) + warn(DeprecationWarning("add_child is deprecated. Set the parent attribute instead.")) + child = getargs("child", kwargs) if child is not None: # if child.parent is a Container, then the mismatch between child.parent and parent # is used to make a soft/external link from the parent to a child elsewhere @@ -302,7 +331,7 @@ def add_child(self, **kwargs): # actually add the child to the parent in parent setter child.parent = self else: - warn('Cannot add None as child to a container %s' % self.name) + warn("Cannot add None as child to a container %s" % self.name) @classmethod def type_hierarchy(cls): @@ -310,24 +339,24 @@ def type_hierarchy(cls): @property def container_source(self): - ''' + """ The source of this Container - ''' + """ return self.__container_source @container_source.setter def container_source(self, source): if self.__container_source is not None: - raise Exception('cannot reassign container_source') + raise Exception("cannot reassign container_source") self.__container_source = source @property def parent(self): - ''' + """ The parent Container of this Container - ''' + """ # do it this way because __parent may not exist yet (not set in constructor) - return getattr(self, '_AbstractContainer__parent', None) + return getattr(self, "_AbstractContainer__parent", None) @parent.setter def parent(self, parent_container): @@ -336,8 +365,9 @@ def parent(self, parent_container): if self.parent is not None: if isinstance(self.parent, AbstractContainer): - raise ValueError(('Cannot reassign parent to Container: %s. ' - 'Parent is already: %s.' % (repr(self), repr(self.parent)))) + raise ValueError( + "Cannot reassign parent to Container: %s. Parent is already: %s." 
% (repr(self), repr(self.parent)) + ) else: if parent_container is None: raise ValueError("Got None for parent of '%s' - cannot overwrite Proxy with NoneType" % repr(self)) @@ -358,10 +388,12 @@ def parent(self, parent_container): def _remove_child(self, child): """Remove a child Container. Intended for use in subclasses that allow dynamic addition of child Containers.""" if not isinstance(child, AbstractContainer): - raise ValueError('Cannot remove non-AbstractContainer object from children.') + raise ValueError("Cannot remove non-AbstractContainer object from children.") if child not in self.children: - raise ValueError("%s '%s' is not a child of %s '%s'." % (child.__class__.__name__, child.name, - self.__class__.__name__, self.name)) + raise ValueError( + "%s '%s' is not a child of %s '%s'." + % (child.__class__.__name__, child.name, self.__class__.__name__, self.name) + ) child.__parent = None self.__children.remove(child) child.set_modified() @@ -383,7 +415,7 @@ def reset_parent(self): class Container(AbstractContainer): """A container that can contain other containers and has special functionality for printing.""" - _pconf_allowed_keys = {'name', 'child', 'required_name', 'doc', 'settable'} + _pconf_allowed_keys = {"name", "child", "required_name", "doc", "settable"} @classmethod def _setter(cls, field): @@ -392,26 +424,31 @@ def _setter(cls, field): ret = [super_setter] # create setter with check for required name # the AbstractContainer that is passed to the setter must have name = required_name - if field.get('required_name', None) is not None: - required_name = field['required_name'] + if field.get("required_name", None) is not None: + required_name = field["required_name"] idx1 = len(ret) - 1 def container_setter(self, val): if val is not None: if not isinstance(val, AbstractContainer): - msg = ("Field '%s' on %s has a required name and must be a subclass of AbstractContainer." - % (field['name'], self.__class__.__name__)) + msg = "Field '%s' on %s has a required name and must be a subclass of AbstractContainer." % ( + field["name"], + self.__class__.__name__, + ) raise ValueError(msg) if val.name != required_name: - msg = ("Field '%s' on %s must be named '%s'." - % (field['name'], self.__class__.__name__, required_name)) + msg = "Field '%s' on %s must be named '%s'." 
% ( + field["name"], + self.__class__.__name__, + required_name, + ) raise ValueError(msg) ret[idx1](self, val) # call the previous setter ret.append(container_setter) # create setter that accepts a value or tuple, list, or dict or values and sets the value's parent to self - if field.get('child', False): + if field.get("child", False): idx2 = len(ret) - 1 def container_setter(self, val): @@ -442,7 +479,7 @@ def __repr__(self): for k in sorted(self.fields): # sorted to enable tests v = self.fields[k] # if isinstance(v, DataIO) or not hasattr(v, '__len__') or len(v) > 0: - if hasattr(v, '__len__'): + if hasattr(v, "__len__"): if isinstance(v, (np.ndarray, list, tuple)): if len(v) > 0: template += " {}: {}\n".format(k, self.__smart_str(v, 1)) @@ -479,53 +516,53 @@ def __smart_str(v, num_indent): if isinstance(v, list) or isinstance(v, tuple): if len(v) and isinstance(v[0], AbstractContainer): - return Container.__smart_str_list(v, num_indent, '(') + return Container.__smart_str_list(v, num_indent, "(") try: return str(np.asarray(v)) except ValueError: - return Container.__smart_str_list(v, num_indent, '(') + return Container.__smart_str_list(v, num_indent, "(") elif isinstance(v, dict): return Container.__smart_str_dict(v, num_indent) elif isinstance(v, set): - return Container.__smart_str_list(sorted(list(v)), num_indent, '{') + return Container.__smart_str_list(sorted(list(v)), num_indent, "{") elif isinstance(v, AbstractContainer): - return "{} {}".format(getattr(v, 'name'), type(v)) + return "{} {}".format(getattr(v, "name"), type(v)) else: return str(v) @staticmethod def __smart_str_list(str_list, num_indent, left_br): - if left_br == '(': - right_br = ')' - if left_br == '{': - right_br = '}' + if left_br == "(": + right_br = ")" + if left_br == "{": + right_br = "}" if len(str_list) == 0: - return left_br + ' ' + right_br - indent = num_indent * 2 * ' ' - indent_in = (num_indent + 1) * 2 * ' ' + return left_br + " " + right_br + indent = num_indent * 2 * " " + indent_in = (num_indent + 1) * 2 * " " out = left_br for v in str_list[:-1]: - out += '\n' + indent_in + Container.__smart_str(v, num_indent + 1) + ',' + out += "\n" + indent_in + Container.__smart_str(v, num_indent + 1) + "," if str_list: - out += '\n' + indent_in + Container.__smart_str(str_list[-1], num_indent + 1) - out += '\n' + indent + right_br + out += "\n" + indent_in + Container.__smart_str(str_list[-1], num_indent + 1) + out += "\n" + indent + right_br return out @staticmethod def __smart_str_dict(d, num_indent): - left_br = '{' - right_br = '}' + left_br = "{" + right_br = "}" if len(d) == 0: - return left_br + ' ' + right_br - indent = num_indent * 2 * ' ' - indent_in = (num_indent + 1) * 2 * ' ' + return left_br + " " + right_br + indent = num_indent * 2 * " " + indent_in = (num_indent + 1) * 2 * " " out = left_br keys = sorted(list(d.keys())) for k in keys[:-1]: - out += '\n' + indent_in + Container.__smart_str(k, num_indent + 1) + ' ' + str(type(d[k])) + ',' + out += "\n" + indent_in + Container.__smart_str(k, num_indent + 1) + " " + str(type(d[k])) + "," if keys: - out += '\n' + indent_in + Container.__smart_str(keys[-1], num_indent + 1) + ' ' + str(type(d[keys[-1]])) - out += '\n' + indent + right_br + out += "\n" + indent_in + Container.__smart_str(keys[-1], num_indent + 1) + " " + str(type(d[keys[-1]])) + out += "\n" + indent + right_br return out @@ -534,10 +571,12 @@ class Data(AbstractContainer): A class for representing dataset containers """ - @docval({'name': 'name', 'type': str, 'doc': 'the name of 
this container'}, - {'name': 'data', 'type': ('scalar_data', 'array_data', 'data'), 'doc': 'the source of the data'}) + @docval( + {"name": "name", "type": str, "doc": "the name of this container"}, + {"name": "data", "type": ("scalar_data", "array_data", "data"), "doc": "the source of the data"}, + ) def __init__(self, **kwargs): - data = popargs('data', kwargs) + data = popargs("data", kwargs) super().__init__(**kwargs) self.__data = data @@ -554,16 +593,22 @@ def shape(self): """ return get_data_shape(self.__data) - @docval({'name': 'dataio', 'type': DataIO, 'doc': 'the DataIO to apply to the data held by this Data'}) + @docval( + { + "name": "dataio", + "type": DataIO, + "doc": "the DataIO to apply to the data held by this Data", + } + ) def set_dataio(self, **kwargs): """ Apply DataIO object to the data held by this Data object """ - dataio = getargs('dataio', kwargs) + dataio = getargs("dataio", kwargs) dataio.data = self.__data self.__data = dataio - @docval({'name': 'func', 'type': types.FunctionType, 'doc': 'a function to transform *data*'}) + @docval({"name": "func", "type": types.FunctionType, "doc": "a function to transform *data*"}) def transform(self, **kwargs): """ Transform data from the current underlying state. @@ -571,7 +616,7 @@ def transform(self, **kwargs): This function can be used to permanently load data from disk, or convert to a different representation, such as a torch.Tensor """ - func = getargs('func', kwargs) + func = getargs("func", kwargs) self.__data = func(self.__data) return self @@ -611,21 +656,20 @@ def extend(self, arg): class DataRegion(Data): - @property @abstractmethod def data(self): - ''' + """ The target data that this region applies to - ''' + """ pass @property @abstractmethod def region(self): - ''' + """ The region that indexes into data e.g. slice or list of indices - ''' + """ pass @@ -647,9 +691,12 @@ class MultiContainerInterface(Container): def __new__(cls, *args, **kwargs): if cls is MultiContainerInterface: raise TypeError("Can't instantiate class MultiContainerInterface.") - if not hasattr(cls, '__clsconf__'): - raise TypeError("MultiContainerInterface subclass %s is missing __clsconf__ attribute. Please check that " - "the class is properly defined." % cls.__name__) + if not hasattr(cls, "__clsconf__"): + raise TypeError( + "MultiContainerInterface subclass %s is missing __clsconf__ attribute." + " Please check that the class is properly defined." 
+ % cls.__name__ + ) return super().__new__(cls, *args, **kwargs) @staticmethod @@ -658,9 +705,9 @@ def __add_article(noun): noun = noun[0] if isinstance(noun, type): noun = noun.__name__ - if noun[0] in ('aeiouAEIOU'): - return 'an %s' % noun - return 'a %s' % noun + if noun[0] in ("aeiouAEIOU"): + return "an %s" % noun + return "a %s" % noun @staticmethod def __join(argtype): @@ -683,7 +730,7 @@ def tostr(x): if len(args_str) == 2: return " or ".join(tostr(x) for x in args_str) else: - return ", ".join(tostr(x) for x in args_str[:-1]) + ', or ' + args_str[-1] + return ", ".join(tostr(x) for x in args_str[:-1]) + ", or " + args_str[-1] else: return tostr(argtype) @@ -691,18 +738,24 @@ def tostr(x): def __make_get(cls, func_name, attr_name, container_type): doc = "Get %s from this %s" % (cls.__add_article(container_type), cls.__name__) - @docval({'name': 'name', 'type': str, 'doc': 'the name of the %s' % cls.__join(container_type), - 'default': None}, - rtype=container_type, returns='the %s with the given name' % cls.__join(container_type), - func_name=func_name, doc=doc) + @docval( + {"name": "name", "type": str, "doc": "the name of the %s" % cls.__join(container_type), "default": None}, + rtype=container_type, + returns="the %s with the given name" % cls.__join(container_type), + func_name=func_name, + doc=doc, + ) def _func(self, **kwargs): - name = getargs('name', kwargs) + name = getargs("name", kwargs) d = getattr(self, attr_name) ret = None if name is None: if len(d) > 1: - msg = ("More than one element in %s of %s '%s' -- must specify a name." - % (attr_name, cls.__name__, self.name)) + msg = "More than one element in %s of %s '%s' -- must specify a name." % ( + attr_name, + cls.__name__, + self.name, + ) raise ValueError(msg) elif len(d) == 0: msg = "%s of %s '%s' is empty." % (attr_name, cls.__name__, self.name) @@ -723,19 +776,25 @@ def _func(self, **kwargs): def __make_getitem(cls, attr_name, container_type): doc = "Get %s from this %s" % (cls.__add_article(container_type), cls.__name__) - @docval({'name': 'name', 'type': str, 'doc': 'the name of the %s' % cls.__join(container_type), - 'default': None}, - rtype=container_type, returns='the %s with the given name' % cls.__join(container_type), - func_name='__getitem__', doc=doc) + @docval( + {"name": "name", "type": str, "doc": "the name of the %s" % cls.__join(container_type), "default": None}, + rtype=container_type, + returns="the %s with the given name" % cls.__join(container_type), + func_name="__getitem__", + doc=doc, + ) def _func(self, **kwargs): # NOTE this is the same code as the getter but with different error messages - name = getargs('name', kwargs) + name = getargs("name", kwargs) d = getattr(self, attr_name) ret = None if name is None: if len(d) > 1: - msg = ("More than one %s in %s '%s' -- must specify a name." - % (cls.__join(container_type), cls.__name__, self.name)) + msg = "More than one %s in %s '%s' -- must specify a name." % ( + cls.__join(container_type), + cls.__name__, + self.name, + ) raise ValueError(msg) elif len(d) == 0: msg = "%s '%s' is empty." 
% (cls.__name__, self.name) @@ -756,9 +815,15 @@ def _func(self, **kwargs): def __make_add(cls, func_name, attr_name, container_type): doc = "Add one or multiple %s objects to this %s" % (cls.__join(container_type), cls.__name__) - @docval({'name': attr_name, 'type': (list, tuple, dict, container_type), - 'doc': 'one or multiple %s objects to add to this %s' % (cls.__join(container_type), cls.__name__)}, - func_name=func_name, doc=doc) + @docval( + { + "name": attr_name, + "type": (list, tuple, dict, container_type), + "doc": "one or multiple %s objects to add to this %s" % (cls.__join(container_type), cls.__name__), + }, + func_name=func_name, + doc=doc, + ) def _func(self, **kwargs): container = getargs(attr_name, kwargs) if isinstance(container, container_type): @@ -787,8 +852,13 @@ def _func(self, **kwargs): def __make_create(cls, func_name, add_name, container_type): doc = "Create %s object and add it to this %s" % (cls.__add_article(container_type), cls.__name__) - @docval(*get_docval(container_type.__init__), func_name=func_name, doc=doc, - returns="the %s object that was created" % cls.__join(container_type), rtype=container_type) + @docval( + *get_docval(container_type.__init__), + func_name=func_name, + doc=doc, + returns="the %s object that was created" % cls.__join(container_type), + rtype=container_type, + ) def _func(self, **kwargs): ret = container_type(**kwargs) getattr(self, add_name)(ret) @@ -800,19 +870,25 @@ def _func(self, **kwargs): def __make_constructor(cls, clsconf): args = list() for conf in clsconf: - attr_name = conf['attr'] - container_type = conf['type'] - args.append({'name': attr_name, 'type': (list, tuple, dict, container_type), - 'doc': '%s to store in this interface' % cls.__join(container_type), 'default': dict()}) + attr_name = conf["attr"] + container_type = conf["type"] + args.append( + { + "name": attr_name, + "type": (list, tuple, dict, container_type), + "doc": "%s to store in this interface" % cls.__join(container_type), + "default": dict(), + } + ) - args.append({'name': 'name', 'type': str, 'doc': 'the name of this container', 'default': cls.__name__}) + args.append({"name": "name", "type": str, "doc": "the name of this container", "default": cls.__name__}) - @docval(*args, func_name='__init__') + @docval(*args, func_name="__init__") def _func(self, **kwargs): - super().__init__(name=kwargs['name']) + super().__init__(name=kwargs["name"]) for conf in clsconf: - attr_name = conf['attr'] - add_name = conf['add'] + attr_name = conf["attr"] + add_name = conf["add"] container = popargs(attr_name, kwargs) add = getattr(self, add_name) add(container) @@ -828,9 +904,11 @@ def _func(self): # do this here to avoid creating default __init__ which may or may not be overridden in # custom classes and dynamically generated classes if attr not in self.fields: + def _remove_child(child): if child.parent is self: self._remove_child(child) + self.fields[attr] = LabelledDict(attr, remove_callable=_remove_child) return self.fields.get(attr) @@ -841,9 +919,9 @@ def _remove_child(child): def __make_setter(cls, add_name): """Make a setter function for creating a :py:func:`property`""" - @docval({'name': 'val', 'type': (list, tuple, dict), 'doc': 'the sub items to add', 'default': None}) + @docval({"name": "val", "type": (list, tuple, dict), "doc": "the sub items to add", "default": None}) def _func(self, **kwargs): - val = getargs('val', kwargs) + val = getargs("val", kwargs) if val is None: return getattr(self, add_name)(val) @@ -855,7 +933,7 @@ def 
__build_class(cls, name, bases, classdict): """Verify __clsconf__ and create methods based on __clsconf__. This method is called prior to __new__ and __init__ during class declaration in the metaclass. """ - if not hasattr(cls, '__clsconf__'): + if not hasattr(cls, "__clsconf__"): return multi = False @@ -865,27 +943,29 @@ def __build_class(cls, name, bases, classdict): multi = True clsconf = cls.__clsconf__ else: - raise TypeError("'__clsconf__' for MultiContainerInterface subclass %s must be a dict or a list of " - "dicts." % cls.__name__) + raise TypeError( + "'__clsconf__' for MultiContainerInterface subclass %s must be a dict or a list of dicts." + % cls.__name__ + ) for conf_index, conf_dict in enumerate(clsconf): cls.__build_conf_methods(conf_dict, conf_index, multi) # make __getitem__ (square bracket access) only if one conf type is defined if len(clsconf) == 1: - attr = clsconf[0].get('attr') - container_type = clsconf[0].get('type') - setattr(cls, '__getitem__', cls.__make_getitem(attr, container_type)) + attr = clsconf[0].get("attr") + container_type = clsconf[0].get("type") + setattr(cls, "__getitem__", cls.__make_getitem(attr, container_type)) # create the constructor, only if it has not been overridden # i.e. it is the same method as the parent class constructor - if '__init__' not in cls.__dict__: - setattr(cls, '__init__', cls.__make_constructor(clsconf)) + if "__init__" not in cls.__dict__: + setattr(cls, "__init__", cls.__make_constructor(clsconf)) @classmethod def __build_conf_methods(cls, conf_dict, conf_index, multi): # get add method name - add = conf_dict.get('add') + add = conf_dict.get("add") if add is None: msg = "MultiContainerInterface subclass %s is missing 'add' key in __clsconf__" % cls.__name__ if multi: @@ -893,7 +973,7 @@ def __build_conf_methods(cls, conf_dict, conf_index, multi): raise ValueError(msg) # get container attribute name - attr = conf_dict.get('attr') + attr = conf_dict.get("attr") if attr is None: msg = "MultiContainerInterface subclass %s is missing 'attr' key in __clsconf__" % cls.__name__ if multi: @@ -901,7 +981,7 @@ def __build_conf_methods(cls, conf_dict, conf_index, multi): raise ValueError(msg) # get container type - container_type = conf_dict.get('type') + container_type = conf_dict.get("type") if container_type is None: msg = "MultiContainerInterface subclass %s is missing 'type' key in __clsconf__" % cls.__name__ if multi: @@ -919,19 +999,23 @@ def __build_conf_methods(cls, conf_dict, conf_index, multi): setattr(cls, add, cls.__make_add(add, attr, container_type)) # create the create method, only if a single container type is specified - create = conf_dict.get('create') + create = conf_dict.get("create") if create is not None: if isinstance(container_type, type): setattr(cls, create, cls.__make_create(create, add, container_type)) else: - msg = ("Cannot specify 'create' key in __clsconf__ for MultiContainerInterface subclass %s " - "when 'type' key is not a single type") % cls.__name__ + msg = ( + "Cannot specify 'create' key in __clsconf__ for" + " MultiContainerInterface subclass %s when 'type' key is not a" + " single type" + % cls.__name__ + ) if multi: msg += " at index %d" % conf_index raise ValueError(msg) # create the get method - get = conf_dict.get('get') + get = conf_dict.get("get") if get is not None: setattr(cls, get, cls.__make_get(get, attr, container_type)) @@ -977,23 +1061,30 @@ def table(self, val): @ExtenderMeta.pre_init def __build_row_class(cls, name, bases, classdict): - table_cls = getattr(cls, '__table__', 
None) + table_cls = getattr(cls, "__table__", None) if table_cls is not None: - columns = getattr(table_cls, '__columns__') + columns = getattr(table_cls, "__columns__") if cls.__init__ == bases[-1].__init__: # check if __init__ is overridden columns = deepcopy(columns) func_args = list() for col in columns: func_args.append(col) - func_args.append({'name': 'table', 'type': Table, 'default': None, - 'help': 'the table this row is from'}) - func_args.append({'name': 'idx', 'type': int, 'default': None, - 'help': 'the index for this row'}) + func_args.append( + {"name": "table", "type": Table, "default": None, "help": "the table this row is from"} + ) + func_args.append( + { + "name": "idx", + "type": int, + "default": None, + "help": "the index for this row", + } + ) @docval(*func_args) def __init__(self, **kwargs): super(cls, self).__init__() - table, idx = popargs('table', 'idx', kwargs) + table, idx = popargs("table", "idx", kwargs) self.__keys = list() self.__idx = None self.__table = None @@ -1003,18 +1094,18 @@ def __init__(self, **kwargs): self.idx = idx self.table = table - setattr(cls, '__init__', __init__) + setattr(cls, "__init__", __init__) def todict(self): return {k: getattr(self, k) for k in self.__keys} - setattr(cls, 'todict', todict) + setattr(cls, "todict", todict) # set this so Table.row gets set when a Table is instantiated table_cls.__rowclass__ = cls else: if bases != (object,): - raise ValueError('__table__ must be set if sub-classing Row') + raise ValueError("__table__ must be set if sub-classing Row") def __eq__(self, other): return self.idx == other.idx and self.table is other.table @@ -1043,7 +1134,7 @@ def __getitem__(self, idx): class Table(Data): - r''' + r""" Subclasses should specify the class attribute \_\_columns\_\_. This should be a list of dictionaries with the following keys: @@ -1064,7 +1155,7 @@ class Table(Data): A Table class can be paired with a Row class for conveniently working with rows of a Table. This pairing must be indicated in the Row class implementation. See Row for more details. - ''' + """ # This class attribute is used to indicate which Row class should be used when # adding RowGetter functionality to the Table. 
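A minimal sketch of the `Table`/`Row` pairing that the surrounding hunks reformat, for orientation only and not part of this patch. The class names and columns below are hypothetical; the sketch assumes `Table` and `Row` are imported from `hdmf.container` (this file) and relies only on the `__columns__`, `__defaultname__`, `__table__`, and generated `add_row` hooks shown in these hunks:

    from hdmf.container import Row, Table

    class BandTable(Table):
        __defaultname__ = "bands"  # picked up as the default for 'name' by the generated __init__
        # each column dict is handed to @docval, so it needs at least 'name', 'type', and 'doc'
        __columns__ = [
            {"name": "label", "type": str, "doc": "name of the frequency band"},
            {"name": "low", "type": float, "doc": "low cutoff in Hz"},
            {"name": "high", "type": float, "doc": "high cutoff in Hz"},
        ]

    class BandRow(Row):
        __table__ = BandTable  # registers BandRow as BandTable.__rowclass__

    table = BandTable()                              # name defaults to "bands"
    table.add_row(label="theta", low=4.0, high=8.0)  # the generated add_row accepts column kwargs
    df = table.to_dataframe()                        # one DataFrame column per entry in __columns__

Because `BandRow` sets `__table__`, `BandTable.__rowclass__` gets populated and each `BandTable` instance exposes a `row` attribute (a `RowGetter`), so individual rows can also be retrieved by index, for example `table.row[0]`.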
@@ -1072,29 +1163,35 @@ class Table(Data): @ExtenderMeta.pre_init def __build_table_class(cls, name, bases, classdict): - if hasattr(cls, '__columns__'): - columns = getattr(cls, '__columns__') + if hasattr(cls, "__columns__"): + columns = getattr(cls, "__columns__") idx = dict() for i, col in enumerate(columns): - idx[col['name']] = i - setattr(cls, '__colidx__', idx) + idx[col["name"]] = i + setattr(cls, "__colidx__", idx) if cls.__init__ == bases[-1].__init__: # check if __init__ is overridden - name = {'name': 'name', 'type': str, 'doc': 'the name of this table'} - defname = getattr(cls, '__defaultname__', None) + name = {"name": "name", "type": str, "doc": "the name of this table"} + defname = getattr(cls, "__defaultname__", None) if defname is not None: - name['default'] = defname # override the name with the default name if present - - @docval(name, - {'name': 'data', 'type': ('array_data', 'data'), 'doc': 'the data in this table', - 'default': list()}) + name["default"] = defname # override the name with the default name if present + + @docval( + name, + { + "name": "data", + "type": ("array_data", "data"), + "doc": "the data in this table", + "default": list(), + }, + ) def __init__(self, **kwargs): - name, data = getargs('name', 'data', kwargs) - colnames = [i['name'] for i in columns] + name, data = getargs("name", "data", kwargs) + colnames = [i["name"] for i in columns] super(cls, self).__init__(colnames, name, data) - setattr(cls, '__init__', __init__) + setattr(cls, "__init__", __init__) if cls.add_row == bases[-1].add_row: # check if add_row is overridden @@ -1102,15 +1199,30 @@ def __init__(self, **kwargs): def add_row(self, **kwargs): return super(cls, self).add_row(kwargs) - setattr(cls, 'add_row', add_row) + setattr(cls, "add_row", add_row) - @docval({'name': 'columns', 'type': (list, tuple), 'doc': 'a list of the columns in this table'}, - {'name': 'name', 'type': str, 'doc': 'the name of this container'}, - {'name': 'data', 'type': ('array_data', 'data'), 'doc': 'the source of the data', 'default': list()}) + @docval( + { + "name": "columns", + "type": (list, tuple), + "doc": "a list of the columns in this table", + }, + { + "name": "name", + "type": str, + "doc": "the name of this container", + }, + { + "name": "data", + "type": ("array_data", "data"), + "doc": "the source of the data", + "default": list(), + }, + ) def __init__(self, **kwargs): - self.__columns = tuple(popargs('columns', kwargs)) + self.__columns = tuple(popargs("columns", kwargs)) self.__col_index = {name: idx for idx, name in enumerate(self.__columns)} - if getattr(self, '__rowclass__') is not None: + if getattr(self, "__rowclass__") is not None: self.row = RowGetter(self) super().__init__(**kwargs) @@ -1118,11 +1230,11 @@ def __init__(self, **kwargs): def columns(self): return self.__columns - @docval({'name': 'values', 'type': dict, 'doc': 'the values for each column'}) + @docval({"name": "values", "type": dict, "doc": "the values for each column"}) def add_row(self, **kwargs): - values = getargs('values', kwargs) + values = getargs("values", kwargs) if not isinstance(self.data, list): - msg = 'Cannot append row to %s' % type(self.data) + msg = "Cannot append row to %s" % type(self.data) raise ValueError(msg) ret = len(self.data) row = [values[col] for col in self.columns] @@ -1131,9 +1243,9 @@ def add_row(self, **kwargs): return ret def which(self, **kwargs): - ''' + """ Query a table - ''' + """ if len(kwargs) != 1: raise ValueError("only one column can be queried") colname, value = 
kwargs.popitem() @@ -1162,7 +1274,7 @@ def __getitem__(self, args): elif isinstance(args[0], int): col = args[0] else: - raise KeyError('first argument must be a column name or index') + raise KeyError("first argument must be a column name or index") return self.data[idx][col] elif isinstance(args, str): col = self.__col_index.get(args) @@ -1173,31 +1285,35 @@ def __getitem__(self, args): return self.data[idx] def to_dataframe(self): - '''Produce a pandas DataFrame containing this table's data. - ''' + """Produce a pandas DataFrame containing this table's data.""" data = {colname: self[colname] for ii, colname in enumerate(self.columns)} return pd.DataFrame(data) @classmethod @docval( - {'name': 'df', 'type': pd.DataFrame, 'doc': 'input data'}, - {'name': 'name', 'type': str, 'doc': 'the name of this container', 'default': None}, + {"name": "df", "type": pd.DataFrame, "doc": "input data"}, { - 'name': 'extra_ok', - 'type': bool, - 'doc': 'accept (and ignore) unexpected columns on the input dataframe', - 'default': False + "name": "name", + "type": str, + "doc": "the name of this container", + "default": None, + }, + { + "name": "extra_ok", + "type": bool, + "doc": "accept (and ignore) unexpected columns on the input dataframe", + "default": False, }, ) def from_dataframe(cls, **kwargs): - '''Construct an instance of Table (or a subclass) from a pandas DataFrame. The columns of the dataframe + """Construct an instance of Table (or a subclass) from a pandas DataFrame. The columns of the dataframe should match the columns defined on the Table subclass. - ''' + """ - df, name, extra_ok = getargs('df', 'name', 'extra_ok', kwargs) + df, name, extra_ok = getargs("df", "name", "extra_ok", kwargs) - cls_cols = list([col['name'] for col in getattr(cls, '__columns__')]) + cls_cols = list([col["name"] for col in getattr(cls, "__columns__")]) df_cols = list(df.columns) missing_columns = set(cls_cols) - set(df_cols) @@ -1205,7 +1321,7 @@ def from_dataframe(cls, **kwargs): if extra_columns: raise ValueError( - 'unrecognized column(s) {} for table class {} (columns {})'.format( + "unrecognized column(s) {} for table class {} (columns {})".format( extra_columns, cls.__name__, cls_cols ) ) @@ -1216,7 +1332,7 @@ def from_dataframe(cls, **kwargs): elif missing_columns: raise ValueError( - 'missing column(s) {} for table class {} (columns {}, provided {})'.format( + "missing column(s) {} for table class {} (columns {}, provided {})".format( missing_columns, cls.__name__, cls_cols, df_cols ) ) @@ -1224,10 +1340,7 @@ def from_dataframe(cls, **kwargs): data = [] for index, row in df.iterrows(): if use_index: - data.append([ - row[colname] if colname != df.index.name else index - for colname in cls_cols - ]) + data.append([row[colname] if colname != df.index.name else index for colname in cls_cols]) else: data.append(tuple([row[colname] for colname in cls_cols])) diff --git a/src/hdmf/data_utils.py b/src/hdmf/data_utils.py index 967663689..dc032dfa3 100644 --- a/src/hdmf/data_utils.py +++ b/src/hdmf/data_utils.py @@ -1,17 +1,17 @@ import copy -import math import functools # TODO: remove when Python 3.7 support is dropped - see #785 +import math import operator # TODO: remove when Python 3.7 support is dropped from abc import ABCMeta, abstractmethod from collections.abc import Iterable -from warnings import warn +from itertools import chain, product from typing import Tuple -from itertools import product, chain +from warnings import warn import h5py import numpy as np -from .utils import docval, getargs, popargs, 
docval_macro, get_data_shape +from .utils import docval, docval_macro, get_data_shape, getargs, popargs def append_data(data, arg): @@ -19,7 +19,7 @@ def append_data(data, arg): data.append(arg) return data elif isinstance(data, np.ndarray): - return np.append(data, np.expand_dims(arg, axis=0), axis=0) + return np.append(data, np.expand_dims(arg, axis=0), axis=0) elif isinstance(data, h5py.Dataset): shape = list(data.shape) shape[0] += 1 @@ -46,14 +46,14 @@ def extend_data(data, arg): shape = list(data.shape) shape[0] += len(arg) data.resize(shape) - data[-len(arg):] = arg + data[-len(arg) :] = arg return data else: msg = "Data cannot extend object of type '%s'" % type(data) raise ValueError(msg) -@docval_macro('array_data') +@docval_macro("array_data") class AbstractDataChunkIterator(metaclass=ABCMeta): """ Abstract iterator class used to iterate over DataChunks. @@ -138,9 +138,8 @@ class GenericDataChunkIterator(AbstractDataChunkIterator): name="buffer_gb", type=(float, int), doc=( - "If buffer_shape is not specified, it will be inferred as the smallest chunk " - "below the buffer_gb threshold." - "Defaults to 1GB." + "If buffer_shape is not specified, it will be inferred as the smallest" + " chunk below the buffer_gb threshold.Defaults to 1GB." ), default=None, ), @@ -154,8 +153,10 @@ class GenericDataChunkIterator(AbstractDataChunkIterator): name="chunk_mb", type=(float, int), doc=( - "If chunk_shape is not specified, it will be inferred as the smallest chunk " - "below the chunk_mb threshold.", + ( + "If chunk_shape is not specified, it will be inferred as the" + " smallest chunk below the chunk_mb threshold." + ), "Defaults to 1MB.", ), default=None, @@ -193,9 +194,9 @@ def __init__(self, **kwargs): See https://support.hdfgroup.org/HDF5/doc/TechNotes/TechNote-HDF5-ImprovingIOPerformanceCompressedDatasets.pdf for more details. """ - buffer_gb, buffer_shape, chunk_mb, chunk_shape, self.display_progress, self.progress_bar_options = getargs( - "buffer_gb", "buffer_shape", "chunk_mb", "chunk_shape", "display_progress", "progress_bar_options", kwargs - ) + buffer_gb, buffer_shape = getargs("buffer_gb", "buffer_shape", kwargs) + chunk_mb, chunk_shape = getargs("chunk_mb", "chunk_shape", kwargs) + self.display_progress, self.progress_bar_options = getargs("display_progress", "progress_bar_options", kwargs) if buffer_gb is None and buffer_shape is None: buffer_gb = 1.0 @@ -248,12 +249,7 @@ def __init__(self, **kwargs): 1, ) self.buffer_selection_generator = ( - tuple( - [ - slice(lower_bound, upper_bound) - for lower_bound, upper_bound in zip(lower_bounds, upper_bounds) - ] - ) + tuple([slice(lower_bound, upper_bound) for lower_bound, upper_bound in zip(lower_bounds, upper_bounds)]) for lower_bounds, upper_bounds in zip( product( *[ @@ -263,7 +259,14 @@ def __init__(self, **kwargs): ), product( *[ - chain(range(buffer_shape_axis, max_shape_axis, buffer_shape_axis), [max_shape_axis]) + chain( + range( + buffer_shape_axis, + max_shape_axis, + buffer_shape_axis, + ), + [max_shape_axis], + ) for max_shape_axis, buffer_shape_axis in zip(self.maxshape, self.buffer_shape) ] ), @@ -283,8 +286,8 @@ def __init__(self, **kwargs): self.progress_bar = tqdm(total=self.num_buffers, **self.progress_bar_options) except ImportError: warn( - "You must install tqdm to use the progress bar feature (pip install tqdm)! " - "Progress bar is disabled." + "You must install tqdm to use the progress bar feature (pip install" + " tqdm)! Progress bar is disabled." 
) self.display_progress = False @@ -340,23 +343,21 @@ def _get_default_buffer_shape(self, **kwargs) -> Tuple[int, ...]: """ buffer_gb = getargs("buffer_gb", kwargs) assert buffer_gb > 0, f"buffer_gb ({buffer_gb}) must be greater than zero!" - assert all(chunk_axis > 0 for chunk_axis in self.chunk_shape), ( - f"Some dimensions of chunk_shape ({self.chunk_shape}) are less than zero!" - ) + assert all( + chunk_axis > 0 for chunk_axis in self.chunk_shape + ), f"Some dimensions of chunk_shape ({self.chunk_shape}) are less than zero!" k = math.floor( ( # TODO: replace with below when Python 3.7 support is dropped # buffer_gb * 1e9 / (math.prod(self.chunk_shape) * self.dtype.itemsize) - buffer_gb * 1e9 / (functools.reduce(operator.mul, self.chunk_shape, 1) * self.dtype.itemsize) - ) ** (1 / len(self.chunk_shape)) - ) - return tuple( - [ - min(max(k * x, self.chunk_shape[j]), self.maxshape[j]) - for j, x in enumerate(self.chunk_shape) - ] + buffer_gb + * 1e9 + / (functools.reduce(operator.mul, self.chunk_shape, 1) * self.dtype.itemsize) + ) + ** (1 / len(self.chunk_shape)) ) + return tuple([min(max(k * x, self.chunk_shape[j]), self.maxshape[j]) for j, x in enumerate(self.chunk_shape)]) def recommended_chunk_shape(self) -> Tuple[int, ...]: return self.chunk_shape @@ -378,7 +379,10 @@ def __next__(self): self.progress_bar.update(n=1) try: buffer_selection = next(self.buffer_selection_generator) - return DataChunk(data=self._get_data(selection=buffer_selection), selection=buffer_selection) + return DataChunk( + data=self._get_data(selection=buffer_selection), + selection=buffer_selection, + ) except StopIteration: if self.display_progress: self.progress_bar.write("\n") # Allows text to be written to new lines after completion @@ -445,13 +449,36 @@ class DataChunkIterator(AbstractDataChunkIterator): """ __docval_init = ( - {'name': 'data', 'type': None, 'doc': 'The data object used for iteration', 'default': None}, - {'name': 'maxshape', 'type': tuple, - 'doc': 'The maximum shape of the full data array. Use None to indicate unlimited dimensions', - 'default': None}, - {'name': 'dtype', 'type': np.dtype, 'doc': 'The Numpy data type for the array', 'default': None}, - {'name': 'buffer_size', 'type': int, 'doc': 'Number of values to be buffered in a chunk', 'default': 1}, - {'name': 'iter_axis', 'type': int, 'doc': 'The dimension to iterate over', 'default': 0} + { + "name": "data", + "type": None, + "doc": "The data object used for iteration", + "default": None, + }, + { + "name": "maxshape", + "type": tuple, + "doc": "The maximum shape of the full data array. Use None to indicate unlimited dimensions", + "default": None, + }, + { + "name": "dtype", + "type": np.dtype, + "doc": "The Numpy data type for the array", + "default": None, + }, + { + "name": "buffer_size", + "type": int, + "doc": "Number of values to be buffered in a chunk", + "default": 1, + }, + { + "name": "iter_axis", + "type": int, + "doc": "The dimension to iterate over", + "default": 0, + }, ) @docval(*__docval_init) @@ -461,19 +488,22 @@ def __init__(self, **kwargs): the dtype of the data. 
""" # Get the user parameters - self.data, self.__maxshape, self.__dtype, self.buffer_size, self.iter_axis = getargs('data', - 'maxshape', - 'dtype', - 'buffer_size', - 'iter_axis', - kwargs) + ( + self.data, + self.__maxshape, + self.__dtype, + self.buffer_size, + self.iter_axis, + ) = getargs("data", "maxshape", "dtype", "buffer_size", "iter_axis", kwargs) self.chunk_index = 0 # Create an iterator for the data if possible if isinstance(self.data, Iterable): if self.iter_axis != 0 and isinstance(self.data, (list, tuple)): - warn('Iterating over an axis other than the first dimension of list or tuple data ' - 'involves converting the data object to a numpy ndarray, which may incur a computational ' - 'cost.') + warn( + "Iterating over an axis other than the first dimension of list or" + " tuple data involves converting the data object to a numpy" + " ndarray, which may incur a computational cost." + ) self.data = np.asarray(self.data) if isinstance(self.data, np.ndarray): # iterate over the given axis by adding a new view on data (iter only works on the first dim) @@ -492,7 +522,11 @@ def __init__(self, **kwargs): self.__maxshape = self.data.shape # Avoid the special case of scalar values by making them into a 1D numpy array if len(self.__maxshape) == 0: - self.data = np.asarray([self.data, ]) + self.data = np.asarray( + [ + self.data, + ] + ) self.__maxshape = self.data.shape self.__data_iter = iter(self.data) # Try to get an accurate idea of __maxshape for other Python data structures if possible. @@ -514,7 +548,7 @@ def __init__(self, **kwargs): self.__first_chunk_shape = tuple(1 if i is None else i for i in self.__maxshape) if self.__dtype is None: - raise Exception('Data type could not be determined. Please specify dtype in DataChunkIterator init.') + raise Exception("Data type could not be determined. Please specify dtype in DataChunkIterator init.") @classmethod @docval(*__docval_init) @@ -531,6 +565,7 @@ def _read_next_chunk(self): :returns: self.__next_chunk, i.e., the DataChunk object describing the next chunk """ from h5py import Dataset as H5Dataset + if isinstance(self.data, H5Dataset): start_index = self.chunk_index * self.buffer_size stop_index = start_index + self.buffer_size @@ -585,8 +620,10 @@ def _read_next_chunk(self): self.__next_chunk.data = np.stack(iter_pieces, axis=self.iter_axis) selection = [slice(None)] * len(self.maxshape) - selection[self.iter_axis] = slice(self.__next_chunk_start + curr_chunk_offset, - self.__next_chunk_start + curr_chunk_offset + next_chunk_size) + selection[self.iter_axis] = slice( + self.__next_chunk_start + curr_chunk_offset, + self.__next_chunk_start + curr_chunk_offset + next_chunk_size, + ) self.__next_chunk.selection = tuple(selection) # next chunk should start at self.__next_chunk.selection[self.iter_axis].stop @@ -623,8 +660,7 @@ def __next__(self): if self.__first_chunk_shape is None: self.__first_chunk_shape = self.__next_chunk.data.shape # Keep the next chunk we need to return - curr_chunk = DataChunk(self.__next_chunk.data, - self.__next_chunk.selection) + curr_chunk = DataChunk(self.__next_chunk.data, self.__next_chunk.selection) # Remove the data for the next chunk from our list since we are returning it here. 
# This is to allow the GarbageCollector to remove the data when it goes out of scope and avoid # having 2 full chunks in memory if not necessary @@ -634,19 +670,22 @@ def __next__(self): next = __next__ - @docval(returns='Tuple with the recommended chunk shape or None if no particular shape is recommended.') + @docval(returns="Tuple with the recommended chunk shape or None if no particular shape is recommended.") def recommended_chunk_shape(self): """Recommend a chunk shape. To optimize iterative write the chunk should be aligned with the common shape of chunks returned by __next__ or if those chunks are too large, then a well-aligned subset of those chunks. This may also be any other value in case one wants to recommend chunk shapes to optimize read rather - than write. The default implementation returns None, indicating no preferential chunking option.""" + than write. The default implementation returns None, indicating no preferential chunking option. + """ return None - @docval(returns='Recommended initial shape for the full data. This should be the shape of the full dataset' + - 'if known beforehand or alternatively the minimum shape of the dataset. Return None if no ' + - 'recommendation is available') + @docval( + returns="Recommended initial shape for the full data. This should be the shape of the full dataset" + + "if known beforehand or alternatively the minimum shape of the dataset. Return None if no " + + "recommendation is available" + ) def recommended_data_shape(self): """Recommend an initial shape of the data. This is useful when progressively writing data and we want to recommend an initial size for the dataset""" @@ -686,7 +725,7 @@ def maxshape(self): # Size of self.__next_chunk.data along self.iter_axis is not accurate for maxshape because it is just a # chunk. So try to set maxshape along the dimension self.iter_axis based on the shape of self.data if # possible. Otherwise, use None to represent an unlimited size - if hasattr(self.data, '__len__') and self.iter_axis == 0: + if hasattr(self.data, "__len__") and self.iter_axis == 0: # special case of 1-D array self.__maxshape[0] = len(self.data) else: @@ -712,12 +751,22 @@ class DataChunk: Class used to describe a data chunk. Used in DataChunkIterator. 
""" - @docval({'name': 'data', 'type': np.ndarray, - 'doc': 'Numpy array with the data value(s) of the chunk', 'default': None}, - {'name': 'selection', 'type': None, - 'doc': 'Numpy index tuple describing the location of the chunk', 'default': None}) + @docval( + { + "name": "data", + "type": np.ndarray, + "doc": "Numpy array with the data value(s) of the chunk", + "default": None, + }, + { + "name": "selection", + "type": None, + "doc": "Numpy index tuple describing the location of the chunk", + "default": None, + }, + ) def __init__(self, **kwargs): - self.data, self.selection = getargs('data', 'selection', kwargs) + self.data, self.selection = getargs("data", "selection", kwargs) def __len__(self): """Get the number of values in the data chunk""" @@ -731,20 +780,20 @@ def __getattr__(self, attr): return getattr(self.data, attr) def __copy__(self): - newobj = DataChunk(data=self.data, - selection=self.selection) + newobj = DataChunk(data=self.data, selection=self.selection) return newobj def __deepcopy__(self, memo): - result = DataChunk(data=copy.deepcopy(self.data), - selection=copy.deepcopy(self.selection)) + result = DataChunk( + data=copy.deepcopy(self.data), + selection=copy.deepcopy(self.selection), + ) memo[id(self)] = result return result def astype(self, dtype): """Get a new DataChunk with the self.data converted to the given type""" - return DataChunk(data=self.data.astype(dtype), - selection=self.selection) + return DataChunk(data=self.data.astype(dtype), selection=self.selection) @property def dtype(self): @@ -765,28 +814,31 @@ def get_min_bounds(self): """ if isinstance(self.selection, tuple): # Determine the minimum array dimensions to fit the chunk selection - max_bounds = tuple([x.stop or 0 if isinstance(x, slice) else x+1 for x in self.selection]) + max_bounds = tuple([x.stop or 0 if isinstance(x, slice) else x + 1 for x in self.selection]) elif isinstance(self.selection, int): - max_bounds = (self.selection+1, ) + max_bounds = (self.selection + 1,) elif isinstance(self.selection, slice): - max_bounds = (self.selection.stop or 0, ) + max_bounds = (self.selection.stop or 0,) else: # Note: Technically any numpy index tuple would be allowed, but h5py is not as general and this case # only implements the selections supported by h5py. 
We could add more cases to support a # broader range of valid numpy selection types - msg = ("Chunk selection %s must be a single int, single slice, or tuple of slices " - "and/or integers") % str(self.selection) + msg = "Chunk selection %s must be a single int, single slice, or tuple of slices and/or integers" % str( + self.selection + ) raise TypeError(msg) return max_bounds -def assertEqualShape(data1, - data2, - axes1=None, - axes2=None, - name1=None, - name2=None, - ignore_undetermined=True): +def assertEqualShape( + data1, + data2, + axes1=None, + axes2=None, + name1=None, + name2=None, + ignore_undetermined=True, +): """ Ensure that the shape of data1 and data2 match along the given dimensions @@ -826,26 +878,32 @@ def assertEqualShape(data1, # 1) Check the number of dimensions of the arrays if (response.axes1 is None and response.axes2 is None) and num_dims_1 != num_dims_2: response.result = False - response.error = 'NUM_DIMS_ERROR' + response.error = "NUM_DIMS_ERROR" response.message = response.SHAPE_ERROR[response.error] response.message += " %s is %sD and %s is %sD" % (n1, num_dims_1, n2, num_dims_2) # 2) Check that we have the same number of dimensions to compare on both arrays elif len(response.axes1) != len(response.axes2): response.result = False - response.error = 'NUM_AXES_ERROR' + response.error = "NUM_AXES_ERROR" response.message = response.SHAPE_ERROR[response.error] response.message += " Cannot compare axes %s with %s" % (str(response.axes1), str(response.axes2)) # 3) Check that the datasets have sufficient number of dimensions elif np.max(response.axes1) >= num_dims_1 or np.max(response.axes2) >= num_dims_2: response.result = False - response.error = 'AXIS_OUT_OF_BOUNDS' + response.error = "AXIS_OUT_OF_BOUNDS" response.message = response.SHAPE_ERROR[response.error] if np.max(response.axes1) >= num_dims_1: - response.message += "Insufficient number of dimensions for %s -- Expected %i found %i" % \ - (n1, np.max(response.axes1) + 1, num_dims_1) + response.message += "Insufficient number of dimensions for %s -- Expected %i found %i" % ( + n1, + np.max(response.axes1) + 1, + num_dims_1, + ) elif np.max(response.axes2) >= num_dims_2: - response.message += "Insufficient number of dimensions for %s -- Expected %i found %i" % \ - (n2, np.max(response.axes2) + 1, num_dims_2) + response.message += "Insufficient number of dimensions for %s -- Expected %i found %i" % ( + n2, + np.max(response.axes2) + 1, + num_dims_2, + ) # 4) Compare the length of the dimensions we should validate else: unmatched = [] @@ -868,15 +926,16 @@ def assertEqualShape(data1, response.message += " Ignored undetermined axes %s" % str(response.ignored) else: response.result = False - response.error = 'AXIS_LEN_ERROR' + response.error = "AXIS_LEN_ERROR" response.message = response.SHAPE_ERROR[response.error] - response.message += "Axes %s with size %s of %s did not match dimensions %s with sizes %s of %s." % \ - (str([un[0] for un in response.unmatched]), - str([response.shape1[un[0]] for un in response.unmatched]), - n1, - str([un[1] for un in response.unmatched]), - str([response.shape2[un[1]] for un in response.unmatched]), - n2) + response.message += "Axes %s with size %s of %s did not match dimensions %s with sizes %s of %s." 
% ( + str([un[0] for un in response.unmatched]), + str([response.shape1[un[0]] for un in response.unmatched]), + n1, + str([un[1] for un in response.unmatched]), + str([response.shape2[un[1]] for un in response.unmatched]), + n2, + ) if len(response.ignored) > 0: response.message += " Ignored undetermined axes %s" % str(response.ignored) return response @@ -892,49 +951,106 @@ class ShapeValidatorResult: :ivar message: Message indicating the result of the matching procedure :type messaage: str, None """ - SHAPE_ERROR = {None: 'All required axes matched', - 'NUM_DIMS_ERROR': 'Unequal number of dimensions.', - 'NUM_AXES_ERROR': "Unequal number of axes for comparison.", - 'AXIS_OUT_OF_BOUNDS': "Axis index for comparison out of bounds.", - 'AXIS_LEN_ERROR': "Unequal length of axes."} + + SHAPE_ERROR = { + None: "All required axes matched", + "NUM_DIMS_ERROR": "Unequal number of dimensions.", + "NUM_AXES_ERROR": "Unequal number of axes for comparison.", + "AXIS_OUT_OF_BOUNDS": "Axis index for comparison out of bounds.", + "AXIS_LEN_ERROR": "Unequal length of axes.", + } """ Dict where the Keys are the type of errors that may have occurred during shape comparison and the values are strings with default error messages for the type. """ - @docval({'name': 'result', 'type': bool, 'doc': 'Result of the shape validation', 'default': False}, - {'name': 'message', 'type': str, - 'doc': 'Message describing the result of the shape validation', 'default': None}, - {'name': 'ignored', 'type': tuple, - 'doc': 'Axes that have been ignored in the validaton process', 'default': tuple(), 'shape': (None,)}, - {'name': 'unmatched', 'type': tuple, - 'doc': 'List of axes that did not match during shape validation', 'default': tuple(), 'shape': (None,)}, - {'name': 'error', 'type': str, 'doc': 'Error that may have occurred. One of ERROR_TYPE', 'default': None}, - {'name': 'shape1', 'type': tuple, - 'doc': 'Shape of the first array for comparison', 'default': tuple(), 'shape': (None,)}, - {'name': 'shape2', 'type': tuple, - 'doc': 'Shape of the second array for comparison', 'default': tuple(), 'shape': (None,)}, - {'name': 'axes1', 'type': tuple, - 'doc': 'Axes for the first array that should match', 'default': tuple(), 'shape': (None,)}, - {'name': 'axes2', 'type': tuple, - 'doc': 'Axes for the second array that should match', 'default': tuple(), 'shape': (None,)}, - ) + @docval( + { + "name": "result", + "type": bool, + "doc": "Result of the shape validation", + "default": False, + }, + { + "name": "message", + "type": str, + "doc": "Message describing the result of the shape validation", + "default": None, + }, + { + "name": "ignored", + "type": tuple, + "doc": "Axes that have been ignored in the validaton process", + "default": tuple(), + "shape": (None,), + }, + { + "name": "unmatched", + "type": tuple, + "doc": "List of axes that did not match during shape validation", + "default": tuple(), + "shape": (None,), + }, + { + "name": "error", + "type": str, + "doc": "Error that may have occurred. 
One of ERROR_TYPE", + "default": None, + }, + { + "name": "shape1", + "type": tuple, + "doc": "Shape of the first array for comparison", + "default": tuple(), + "shape": (None,), + }, + { + "name": "shape2", + "type": tuple, + "doc": "Shape of the second array for comparison", + "default": tuple(), + "shape": (None,), + }, + { + "name": "axes1", + "type": tuple, + "doc": "Axes for the first array that should match", + "default": tuple(), + "shape": (None,), + }, + { + "name": "axes2", + "type": tuple, + "doc": "Axes for the second array that should match", + "default": tuple(), + "shape": (None,), + }, + ) def __init__(self, **kwargs): - self.result, self.message, self.ignored, self.unmatched, \ - self.error, self.shape1, self.shape2, self.axes1, self.axes2 = getargs( - 'result', 'message', 'ignored', 'unmatched', 'error', 'shape1', 'shape2', 'axes1', 'axes2', kwargs) + self.result, self.message, self.ignored = getargs("result", "message", "ignored", kwargs) + self.unmatched, self.error = getargs("unmatched", "error", kwargs) + self.shape1, self.shape2, self.axes1, self.axes2 = getargs("shape1", "shape2", "axes1", "axes2", kwargs) def __setattr__(self, key, value): """ Overwrite to ensure that, e.g., error_message is not set to an illegal value. """ - if key == 'error': + if key == "error": if value not in self.SHAPE_ERROR.keys(): - raise ValueError("Illegal error type. Error must be one of ShapeValidatorResult.SHAPE_ERROR: %s" - % str(self.SHAPE_ERROR)) + raise ValueError( + "Illegal error type. Error must be one of ShapeValidatorResult.SHAPE_ERROR: %s" + % str(self.SHAPE_ERROR) + ) else: super().__setattr__(key, value) - elif key in ['shape1', 'shape2', 'axes1', 'axes2', 'ignored', 'unmatched']: # Make sure we sore tuples + elif key in [ + "shape1", + "shape2", + "axes1", + "axes2", + "ignored", + "unmatched", + ]: # Make sure we sore tuples super().__setattr__(key, tuple(value)) else: super().__setattr__(key, value) @@ -943,32 +1059,40 @@ def __getattr__(self, item): """ Overwrite to allow dynamic retrieval of the default message """ - if item == 'default_message': + if item == "default_message": return self.SHAPE_ERROR[self.error] return self.__getattribute__(item) -@docval_macro('data') +@docval_macro("data") class DataIO: """ Base class for wrapping data arrays for I/O. Derived classes of DataIO are typically used to pass dataset-specific I/O parameters to the particular HDMFIO backend. """ - @docval({'name': 'data', - 'type': 'array_data', - 'doc': 'the data to be written', - 'default': None}, - {'name': 'dtype', - 'type': (type, np.dtype), - 'doc': 'the data type of the dataset. Not used if data is specified.', - 'default': None}, - {'name': 'shape', - 'type': tuple, - 'doc': 'the shape of the dataset. Not used if data is specified.', - 'default': None}) + @docval( + { + "name": "data", + "type": "array_data", + "doc": "the data to be written", + "default": None, + }, + { + "name": "dtype", + "type": (type, np.dtype), + "doc": "the data type of the dataset. Not used if data is specified.", + "default": None, + }, + { + "name": "shape", + "type": tuple, + "doc": "the shape of the dataset. 
Not used if data is specified.", + "default": None, + }, + ) def __init__(self, **kwargs): - data, dtype, shape = popargs('data', 'dtype', 'shape', kwargs) + data, dtype, shape = popargs("data", "dtype", "shape", kwargs) if data is None: if (dtype is None) ^ (shape is None): raise ValueError("Must specify 'dtype' and 'shape' if not specifying 'data'") @@ -1063,7 +1187,7 @@ def __bool__(self): def __getattr__(self, attr): """Delegate attribute lookup to data object""" - if attr == '__array_struct__' and not self.valid: + if attr == "__array_struct__" and not self.valid: # np.array() checks __array__ or __array_struct__ attribute dep. on numpy version raise InvalidDataIOError("Cannot convert data to array. Data is not valid.") if not self.valid: @@ -1085,7 +1209,7 @@ def __array__(self): """ if not self.valid: raise InvalidDataIOError("Cannot convert data to array. Data is not valid.") - if hasattr(self.data, '__array__'): + if hasattr(self.data, "__array__"): return self.data.__array__() elif isinstance(self.data, DataChunkIterator): raise NotImplementedError("Conversion of DataChunkIterator to array not supported") diff --git a/src/hdmf/monitor.py b/src/hdmf/monitor.py index 823ccf72d..737f89381 100644 --- a/src/hdmf/monitor.py +++ b/src/hdmf/monitor.py @@ -1,6 +1,6 @@ from abc import ABCMeta, abstractmethod -from .data_utils import AbstractDataChunkIterator, DataChunkIterator, DataChunk +from .data_utils import AbstractDataChunkIterator, DataChunk, DataChunkIterator from .utils import docval, getargs @@ -9,12 +9,11 @@ class NotYetExhausted(Exception): class DataChunkProcessor(AbstractDataChunkIterator, metaclass=ABCMeta): - - @docval({'name': 'data', 'type': DataChunkIterator, 'doc': 'the DataChunkIterator to analyze'}) + @docval({"name": "data", "type": DataChunkIterator, "doc": "the DataChunkIterator to analyze"}) def __init__(self, **kwargs): """Initialize the DataChunkIterator""" # Get the user parameters - self.__dci = getargs('data', kwargs) + self.__dci = getargs("data", kwargs) def __next__(self): try: @@ -35,39 +34,38 @@ def recommended_data_shape(self): return self.__dci.recommended_data_shape() def get_final_result(self, **kwargs): - ''' Return the result of processing data fed by this DataChunkIterator ''' + """Return the result of processing data fed by this DataChunkIterator""" if not self.__done: raise NotYetExhausted() return self.compute_final_result() @abstractmethod - @docval({'name': 'data_chunk', 'type': DataChunk, 'doc': 'a chunk to process'}) + @docval({"name": "data_chunk", "type": DataChunk, "doc": "a chunk to process"}) def process_data_chunk(self, **kwargs): - ''' This method should take in a DataChunk, - and process it. - ''' + """This method should take in a DataChunk, + and process it. 
+ """ pass @abstractmethod - @docval(returns='the result of processing this stream') + @docval(returns="the result of processing this stream") def compute_final_result(self, **kwargs): - ''' Return the result of processing this stream - Should raise NotYetExhaused exception - ''' + """Return the result of processing this stream + Should raise NotYetExhaused exception + """ pass class NumSampleCounter(DataChunkProcessor): - def __init__(self, **kwargs): super().__init__(**kwargs) self.__sample_count = 0 - @docval({'name': 'data_chunk', 'type': DataChunk, 'doc': 'a chunk to process'}) + @docval({"name": "data_chunk", "type": DataChunk, "doc": "a chunk to process"}) def process_data_chunk(self, **kwargs): - dc = getargs('data_chunk', kwargs) + dc = getargs("data_chunk", kwargs) self.__sample_count += len(dc) - @docval(returns='the result of processing this stream') + @docval(returns="the result of processing this stream") def compute_final_result(self, **kwargs): return self.__sample_count diff --git a/src/hdmf/query.py b/src/hdmf/query.py index 835b295c5..9eb6bdaa7 100644 --- a/src/hdmf/query.py +++ b/src/hdmf/query.py @@ -3,17 +3,17 @@ import numpy as np from .array import Array -from .utils import ExtenderMeta, docval_macro, docval, getargs +from .utils import ExtenderMeta, docval, docval_macro, getargs class Query(metaclass=ExtenderMeta): __operations__ = ( - '__lt__', - '__gt__', - '__le__', - '__ge__', - '__eq__', - '__ne__', + "__lt__", + "__gt__", + "__le__", + "__ge__", + "__eq__", + "__ne__", ) @classmethod @@ -26,8 +26,12 @@ def __make_operators(cls, name, bases, classdict): if not isinstance(cls.__operations__, tuple): raise TypeError("'__operations__' must be of type tuple") # add any new operations - if len(bases) and 'Query' in globals() and issubclass(bases[-1], Query) \ - and bases[-1].__operations__ is not cls.__operations__: + if ( + len(bases) + and "Query" in globals() + and issubclass(bases[-1], Query) + and bases[-1].__operations__ is not cls.__operations__ + ): new_operations = list(cls.__operations__) new_operations[0:0] = bases[-1].__operations__ cls.__operations__ = tuple(new_operations) @@ -42,9 +46,9 @@ def __init__(self, obj, op, arg): self.collapsed = None self.expanded = None - @docval({'name': 'expand', 'type': bool, 'help': 'whether or not to expand result', 'default': True}) + @docval({"name": "expand", "type": bool, "help": "whether or not to expand result", "default": True}) def evaluate(self, **kwargs): - expand = getargs('expand', kwargs) + expand = getargs("expand", kwargs) if expand: if self.expanded is None: self.expanded = self.__evalhelper() @@ -92,15 +96,15 @@ def __contains__(self, other): return NotImplemented -@docval_macro('array_data') +@docval_macro("array_data") class HDMFDataset(metaclass=ExtenderMeta): __operations__ = ( - '__lt__', - '__gt__', - '__le__', - '__ge__', - '__eq__', - '__ne__', + "__lt__", + "__gt__", + "__le__", + "__ge__", + "__eq__", + "__ne__", ) @classmethod @@ -108,7 +112,7 @@ def __build_operation(cls, op): def __func(self, arg): return Query(self, op, arg) - setattr(__func, '__name__', op) + setattr(__func, "__name__", op) return __func @ExtenderMeta.pre_init @@ -116,8 +120,12 @@ def __make_operators(cls, name, bases, classdict): if not isinstance(cls.__operations__, tuple): raise TypeError("'__operations__' must be of type tuple") # add any new operations - if len(bases) and 'Query' in globals() and issubclass(bases[-1], Query) \ - and bases[-1].__operations__ is not cls.__operations__: + if ( + len(bases) + and 
"Query" in globals() + and issubclass(bases[-1], Query) + and bases[-1].__operations__ is not cls.__operations__ + ): new_operations = list(cls.__operations__) new_operations[0:0] = bases[-1].__operations__ cls.__operations__ = tuple(new_operations) @@ -138,10 +146,10 @@ def __getitem__(self, key): idx = self.__evaluate_key(key) return self.dataset[idx] - @docval({'name': 'dataset', 'type': ('array_data', Array), 'doc': 'the HDF5 file lazily evaluate'}) + @docval({"name": "dataset", "type": ("array_data", Array), "doc": "the HDF5 file lazily evaluate"}) def __init__(self, **kwargs): super().__init__() - self.__dataset = getargs('dataset', kwargs) + self.__dataset = getargs("dataset", kwargs) @property def dataset(self): diff --git a/src/hdmf/region.py b/src/hdmf/region.py index 9feeba401..6e5566ef9 100644 --- a/src/hdmf/region.py +++ b/src/hdmf/region.py @@ -6,17 +6,19 @@ class RegionSlicer(DataRegion, metaclass=ABCMeta): - ''' + """ A abstract base class to control getting using a region Subclasses must implement `__getitem__` and `__len__` - ''' + """ - @docval({'name': 'target', 'type': None, 'doc': 'the target to slice'}, - {'name': 'slice', 'type': None, 'doc': 'the region to slice'}) + @docval( + {"name": "target", "type": None, "doc": "the target to slice"}, + {"name": "slice", "type": None, "doc": "the region to slice"}, + ) def __init__(self, **kwargs): - self.__target = getargs('target', kwargs) - self.__slice = getargs('slice', kwargs) + self.__target = getargs("target", kwargs) + self.__slice = getargs("slice", kwargs) @property def data(self): @@ -54,10 +56,12 @@ def __len__(self): class ListSlicer(RegionSlicer): """Implementation of RegionSlicer for slicing Lists and Data""" - @docval({'name': 'dataset', 'type': (list, tuple, Data), 'doc': 'the dataset to slice'}, - {'name': 'region', 'type': (list, tuple, slice), 'doc': 'the region reference to use to slice'}) + @docval( + {"name": "dataset", "type": (list, tuple, Data), "doc": "the dataset to slice"}, + {"name": "region", "type": (list, tuple, slice), "doc": "the region reference to use to slice"}, + ) def __init__(self, **kwargs): - self.__dataset, self.__region = getargs('dataset', 'region', kwargs) + self.__dataset, self.__region = getargs("dataset", "region", kwargs) super().__init__(self.__dataset, self.__region) if isinstance(self.__region, slice): self.__getter = itemgetter(self.__region) @@ -70,7 +74,7 @@ def __read_region(self): """ Internal helper function used to define self._read """ - if not hasattr(self, '_read'): + if not hasattr(self, "_read"): self._read = self.__getter(self.__dataset) del self.__getter diff --git a/src/hdmf/spec/__init__.py b/src/hdmf/spec/__init__.py index 09ad6d073..b520c192b 100644 --- a/src/hdmf/spec/__init__.py +++ b/src/hdmf/spec/__init__.py @@ -1,5 +1,14 @@ from .catalog import SpecCatalog from .namespace import NamespaceCatalog, SpecNamespace, SpecReader -from .spec import (AttributeSpec, DatasetSpec, DtypeHelper, DtypeSpec, GroupSpec, LinkSpec, - NAME_WILDCARD, RefSpec, Spec) +from .spec import ( + NAME_WILDCARD, + AttributeSpec, + DatasetSpec, + DtypeHelper, + DtypeSpec, + GroupSpec, + LinkSpec, + RefSpec, + Spec, +) from .write import NamespaceBuilder, SpecWriter, export_spec diff --git a/src/hdmf/spec/catalog.py b/src/hdmf/spec/catalog.py index 636eb3bc0..2df8ed6aa 100644 --- a/src/hdmf/spec/catalog.py +++ b/src/hdmf/spec/catalog.py @@ -1,14 +1,13 @@ import copy from collections import OrderedDict -from .spec import BaseStorageSpec, GroupSpec from ..utils import docval, getargs 
+from .spec import BaseStorageSpec, GroupSpec class SpecCatalog: - def __init__(self): - ''' + """ Create a new catalog for storing specifications ** Private Instance Variables ** @@ -20,24 +19,34 @@ def __init__(self): NOTE: Always use SpecCatalog.get_hierarchy(...) to retrieve the hierarchy as this dictionary is used like a cache, i.e., to avoid repeated calculation of the hierarchy but the contents are computed on first request by SpecCatalog.get_hierarchy(...) - ''' + """ self.__specs = OrderedDict() self.__parent_types = dict() self.__hierarchy = dict() self.__spec_source_files = dict() - @docval({'name': 'spec', 'type': BaseStorageSpec, 'doc': 'a Spec object'}, - {'name': 'source_file', 'type': str, - 'doc': 'path to the source file from which the spec was loaded', 'default': None}) + @docval( + { + "name": "spec", + "type": BaseStorageSpec, + "doc": "a Spec object", + }, + { + "name": "source_file", + "type": str, + "doc": "path to the source file from which the spec was loaded", + "default": None, + }, + ) def register_spec(self, **kwargs): - ''' + """ Associate a specified object type with a specification - ''' - spec, source_file = getargs('spec', 'source_file', kwargs) + """ + spec, source_file = getargs("spec", "source_file", kwargs) ndt = spec.data_type_inc ndt_def = spec.data_type_def if ndt_def is None: - raise ValueError('cannot register spec that has no data_type_def') + raise ValueError("cannot register spec that has no data_type_def") if ndt_def != ndt: self.__parent_types[ndt_def] = ndt type_name = ndt_def if ndt_def is not None else ndt @@ -47,46 +56,61 @@ def register_spec(self, **kwargs): self.__specs[type_name] = spec self.__spec_source_files[type_name] = source_file - @docval({'name': 'data_type', 'type': str, 'doc': 'the data_type to get the Spec for'}, - returns="the specification for writing the given object type to HDF5 ", rtype='Spec') + @docval( + {"name": "data_type", "type": str, "doc": "the data_type to get the Spec for"}, + returns="the specification for writing the given object type to HDF5 ", + rtype="Spec", + ) def get_spec(self, **kwargs): - ''' + """ Get the Spec object for the given type - ''' - data_type = getargs('data_type', kwargs) + """ + data_type = getargs("data_type", kwargs) return self.__specs.get(data_type, None) @docval(rtype=tuple) def get_registered_types(self, **kwargs): - ''' + """ Return all registered specifications - ''' + """ # kwargs is not used here but is used by docval return tuple(self.__specs.keys()) - @docval({'name': 'data_type', 'type': str, 'doc': 'the data_type of the spec to get the source file for'}, - returns="the path to source specification file from which the spec was originally loaded or None ", - rtype='str') + @docval( + {"name": "data_type", "type": str, "doc": "the data_type of the spec to get the source file for"}, + returns="the path to source specification file from which the spec was originally loaded or None ", + rtype="str", + ) def get_spec_source_file(self, **kwargs): - ''' + """ Return the path to the source file from which the spec for the given type was loaded from. None is returned if no file path is available for the spec. Note: The spec in the file may not be identical to the object in case the spec is modified after load. 
- ''' - data_type = getargs('data_type', kwargs) + """ + data_type = getargs("data_type", kwargs) return self.__spec_source_files.get(data_type, None) - @docval({'name': 'spec', 'type': BaseStorageSpec, 'doc': 'the Spec object to register'}, - {'name': 'source_file', - 'type': str, - 'doc': 'path to the source file from which the spec was loaded', 'default': None}, - rtype=tuple, returns='the types that were registered with this spec') + @docval( + { + "name": "spec", + "type": BaseStorageSpec, + "doc": "the Spec object to register", + }, + { + "name": "source_file", + "type": str, + "doc": "path to the source file from which the spec was loaded", + "default": None, + }, + rtype=tuple, + returns="the types that were registered with this spec", + ) def auto_register(self, **kwargs): - ''' + """ Register this specification and all sub-specification using data_type as object type name - ''' - spec, source_file = getargs('spec', 'source_file', kwargs) + """ + spec, source_file = getargs("spec", "source_file", kwargs) ndt = spec.data_type_def ret = list() if ndt is not None: @@ -102,10 +126,11 @@ def auto_register(self, **kwargs): ret.extend(self.auto_register(group_spec, source_file)) return tuple(ret) - @docval({'name': 'data_type', 'type': (str, type), - 'doc': 'the data_type to get the hierarchy of'}, - returns="Tuple of strings with the names of the types the given data_type inherits from.", - rtype=tuple) + @docval( + {"name": "data_type", "type": (str, type), "doc": "the data_type to get the hierarchy of"}, + returns="Tuple of strings with the names of the types the given data_type inherits from.", + rtype=tuple, + ) def get_hierarchy(self, **kwargs): """ For a given type get the type inheritance hierarchy for that type. @@ -113,7 +138,7 @@ def get_hierarchy(self, **kwargs): E.g., if we have a type MyContainer that inherits from BaseContainer then the result will be a tuple with the strings ('MyContainer', 'BaseContainer') """ - data_type = getargs('data_type', kwargs) + data_type = getargs("data_type", kwargs) if isinstance(data_type, type): data_type = data_type.__name__ ret = self.__hierarchy.get(data_type) @@ -132,8 +157,10 @@ def get_hierarchy(self, **kwargs): tmp_hier = tmp_hier[1:] return tuple(ret) - @docval(returns="Hierarchically nested OrderedDict with the hierarchy of all the types", - rtype=OrderedDict) + @docval( + returns="Hierarchically nested OrderedDict with the hierarchy of all the types", + rtype=OrderedDict, + ) def get_full_hierarchy(self): """ Get the complete hierarchy of all types. The function attempts to sort types by name using @@ -159,13 +186,21 @@ def get_type_hierarchy(data_type, spec_catalog): return type_hierarchy - @docval({'name': 'data_type', 'type': (str, type), - 'doc': 'the data_type to get the subtypes for'}, - {'name': 'recursive', 'type': bool, - 'doc': 'recursively get all subtypes. Set to False to only get the direct subtypes', - 'default': True}, - returns="Tuple of strings with the names of all types of the given data_type.", - rtype=tuple) + @docval( + { + "name": "data_type", + "type": (str, type), + "doc": "the data_type to get the subtypes for", + }, + { + "name": "recursive", + "type": bool, + "doc": "recursively get all subtypes. Set to False to only get the direct subtypes", + "default": True, + }, + returns="Tuple of strings with the names of all types of the given data_type.", + rtype=tuple, + ) def get_subtypes(self, **kwargs): """ For a given data type recursively find all the subtypes that inherit from it. 
@@ -179,7 +214,7 @@ def get_subtypes(self, **kwargs): In this case, the subtypes of BaseContainer would be (AContainer, ADContainer, BContainer), the subtypes of AContainer would be (ADContainer), and the subtypes of BContainer would be empty (). """ - data_type, recursive = getargs('data_type', 'recursive', kwargs) + data_type, recursive = getargs("data_type", "recursive", kwargs) curr_spec = self.get_spec(data_type) if isinstance(curr_spec, BaseStorageSpec): # Only BaseStorageSpec have data_type_inc/def keys subtypes = [] diff --git a/src/hdmf/spec/namespace.py b/src/hdmf/spec/namespace.py index 73c41a1d8..f5af0355a 100644 --- a/src/hdmf/spec/namespace.py +++ b/src/hdmf/spec/namespace.py @@ -1,5 +1,4 @@ import os.path -import ruamel.yaml as yaml import string from abc import ABCMeta, abstractmethod from collections import OrderedDict @@ -7,23 +6,64 @@ from datetime import datetime from warnings import warn +import ruamel.yaml as yaml + +from ..utils import docval, get_docval, getargs, popargs from .catalog import SpecCatalog from .spec import DatasetSpec, GroupSpec -from ..utils import docval, getargs, popargs, get_docval _namespace_args = [ - {'name': 'doc', 'type': str, 'doc': 'a description about what this namespace represents'}, - {'name': 'name', 'type': str, 'doc': 'the name of this namespace'}, - {'name': 'schema', 'type': list, 'doc': 'location of schema specification files or other Namespaces'}, - {'name': 'full_name', 'type': str, 'doc': 'extended full name of this namespace', 'default': None}, - {'name': 'version', 'type': (str, tuple, list), 'doc': 'Version number of the namespace', 'default': None}, - {'name': 'date', 'type': (datetime, str), - 'doc': "Date last modified or released. Formatting is %Y-%m-%d %H:%M:%S, e.g, 2017-04-25 17:14:13", - 'default': None}, - {'name': 'author', 'type': (str, list), 'doc': 'Author or list of authors.', 'default': None}, - {'name': 'contact', 'type': (str, list), - 'doc': 'List of emails. Ordering should be the same as for author', 'default': None}, - {'name': 'catalog', 'type': SpecCatalog, 'doc': 'The SpecCatalog object for this SpecNamespace', 'default': None} + { + "name": "doc", + "type": str, + "doc": "a description about what this namespace represents", + }, + { + "name": "name", + "type": str, + "doc": "the name of this namespace", + }, + { + "name": "schema", + "type": list, + "doc": "location of schema specification files or other Namespaces", + }, + { + "name": "full_name", + "type": str, + "doc": "extended full name of this namespace", + "default": None, + }, + { + "name": "version", + "type": (str, tuple, list), + "doc": "Version number of the namespace", + "default": None, + }, + { + "name": "date", + "type": (datetime, str), + "doc": "Date last modified or released. Formatting is %Y-%m-%d %H:%M:%S, e.g, 2017-04-25 17:14:13", + "default": None, + }, + { + "name": "author", + "type": (str, list), + "doc": "Author or list of authors.", + "default": None, + }, + { + "name": "contact", + "type": (str, list), + "doc": "List of emails. 
Ordering should be the same as for author", + "default": None, + }, + { + "name": "catalog", + "type": SpecCatalog, + "doc": "The SpecCatalog object for this SpecNamespace", + "default": None, + }, ] @@ -32,22 +72,22 @@ class SpecNamespace(dict): A namespace for specifications """ - __types_key = 'data_types' + __types_key = "data_types" UNVERSIONED = None # value representing missing version @docval(*_namespace_args) def __init__(self, **kwargs): - doc, full_name, name, version, date, author, contact, schema, catalog = \ - popargs('doc', 'full_name', 'name', 'version', 'date', 'author', 'contact', 'schema', 'catalog', kwargs) + doc, full_name, name, version, date = popargs("doc", "full_name", "name", "version", "date", kwargs) + author, contact, schema, catalog = popargs("author", "contact", "schema", "catalog", kwargs) super().__init__() - self['doc'] = doc - self['schema'] = schema + self["doc"] = doc + self["schema"] = schema if any(c in string.whitespace for c in name): raise ValueError("'name' must not contain any whitespace") - self['name'] = name + self["name"] = name if full_name is not None: - self['full_name'] = full_name + self["full_name"] = full_name if version == str(SpecNamespace.UNVERSIONED): # the unversioned version may be written to file as a string and read from file as a string warn("Loaded namespace '%s' is unversioned. Please notify the extension author." % name) @@ -55,40 +95,42 @@ def __init__(self, **kwargs): if version is None: # version is required on write -- see YAMLSpecWriter.write_namespace -- but can be None on read in order to # be able to read older files with extensions that are missing the version key. - warn(("Loaded namespace '%s' is missing the required key 'version'. Version will be set to '%s'. " - "Please notify the extension author.") % (name, SpecNamespace.UNVERSIONED)) + warn( + "Loaded namespace '%s' is missing the required key 'version'. Version" + " will be set to '%s'. Please notify the extension author." % (name, SpecNamespace.UNVERSIONED) + ) version = SpecNamespace.UNVERSIONED - self['version'] = version + self["version"] = version if date is not None: - self['date'] = date + self["date"] = date if author is not None: - self['author'] = author + self["author"] = author if contact is not None: - self['contact'] = contact + self["contact"] = contact self.__catalog = catalog if catalog is not None else SpecCatalog() @classmethod def types_key(cls): - ''' Get the key used for specifying types to include from a file or namespace + """Get the key used for specifying types to include from a file or namespace Override this method to use a different name for 'data_types' - ''' + """ return cls.__types_key @property def full_name(self): """String with full name or None""" - return self.get('full_name', None) + return self.get("full_name", None) @property def contact(self): """String or list of strings with the contacts or None""" - return self.get('contact', None) + return self.get("contact", None) @property def author(self): """String or list of strings with the authors or None""" - return self.get('author', None) + return self.get("author", None) @property def version(self): @@ -96,36 +138,39 @@ def version(self): String, list, or tuple with the version or SpecNamespace.UNVERSIONED if the version is missing or empty """ - return self.get('version', None) or SpecNamespace.UNVERSIONED + return self.get("version", None) or SpecNamespace.UNVERSIONED @property def date(self): """Date last modified or released. 
:return: datetime object, string, or None""" - return self.get('date', None) + return self.get("date", None) @property def name(self): """String with short name or None""" - return self.get('name', None) + return self.get("name", None) @property def doc(self): - return self['doc'] + return self["doc"] @property def schema(self): - return self['schema'] + return self["schema"] def get_source_files(self): """ Get the list of names of the source files included the schema of the namespace """ - return [item['source'] for item in self.schema if 'source' in item] + return [item["source"] for item in self.schema if "source" in item] - @docval({'name': 'sourcefile', 'type': str, 'doc': 'Name of the source file'}, - returns='Dict with the source file documentation', rtype=dict) + @docval( + {"name": "sourcefile", "type": str, "doc": "Name of the source file"}, + returns="Dict with the source file documentation", + rtype=dict, + ) def get_source_description(self, sourcefile): """ Get the description of a source file as described in the namespace. The result is a @@ -133,7 +178,7 @@ def get_source_description(self, sourcefile): imported from the source file """ for item in self.schema: - if item.get('source', None) == sourcefile: + if item.get("source", None) == sourcefile: return item @property @@ -141,10 +186,10 @@ def catalog(self): """The SpecCatalog containing all the Specs""" return self.__catalog - @docval({'name': 'data_type', 'type': (str, type), 'doc': 'the data_type to get the spec for'}) + @docval({"name": "data_type", "type": (str, type), "doc": "the data_type to get the spec for"}) def get_spec(self, **kwargs): """Get the Spec object for the given data type""" - data_type = getargs('data_type', kwargs) + data_type = getargs("data_type", kwargs) spec = self.__catalog.get_spec(data_type) if spec is None: raise ValueError("No specification for '%s' in namespace '%s'" % (data_type, self.name)) @@ -155,28 +200,30 @@ def get_registered_types(self, **kwargs): """Get the available types in this namespace""" return self.__catalog.get_registered_types() - @docval({'name': 'data_type', 'type': (str, type), 'doc': 'the data_type to get the hierarchy of'}, - returns="a tuple with the type hierarchy", rtype=tuple) + @docval( + {"name": "data_type", "type": (str, type), "doc": "the data_type to get the hierarchy of"}, + returns="a tuple with the type hierarchy", + rtype=tuple, + ) def get_hierarchy(self, **kwargs): - ''' Get the extension hierarchy for the given data_type in this namespace''' - data_type = getargs('data_type', kwargs) + """Get the extension hierarchy for the given data_type in this namespace""" + data_type = getargs("data_type", kwargs) return self.__catalog.get_hierarchy(data_type) @classmethod def build_namespace(cls, **spec_dict): kwargs = copy(spec_dict) try: - args = [kwargs.pop(x['name']) for x in get_docval(cls.__init__) if 'default' not in x] + args = [kwargs.pop(x["name"]) for x in get_docval(cls.__init__) if "default" not in x] except KeyError as e: raise KeyError("'%s' not found in %s" % (e.args[0], str(spec_dict))) return cls(*args, **kwargs) class SpecReader(metaclass=ABCMeta): - - @docval({'name': 'source', 'type': str, 'doc': 'the source from which this reader reads from'}) + @docval({"name": "source", "type": str, "doc": "the source from which this reader reads from"}) def __init__(self, **kwargs): - self.__source = getargs('source', kwargs) + self.__source = getargs("source", kwargs) @property def source(self): @@ -192,27 +239,26 @@ def read_namespace(self): class 
YAMLSpecReader(SpecReader): - - @docval({'name': 'indir', 'type': str, 'doc': 'the path spec files are relative to', 'default': '.'}) + @docval({"name": "indir", "type": str, "doc": "the path spec files are relative to", "default": "."}) def __init__(self, **kwargs): - super().__init__(source=kwargs['indir']) + super().__init__(source=kwargs["indir"]) def read_namespace(self, namespace_path): namespaces = None - with open(namespace_path, 'r') as stream: - yaml_obj = yaml.YAML(typ='safe', pure=True) + with open(namespace_path, "r") as stream: + yaml_obj = yaml.YAML(typ="safe", pure=True) d = yaml_obj.load(stream) - namespaces = d.get('namespaces') + namespaces = d.get("namespaces") if namespaces is None: raise ValueError("no 'namespaces' found in %s" % namespace_path) return namespaces def read_spec(self, spec_path): specs = None - with open(self.__get_spec_path(spec_path), 'r') as stream: - yaml_obj = yaml.YAML(typ='safe', pure=True) + with open(self.__get_spec_path(spec_path), "r") as stream: + yaml_obj = yaml.YAML(typ="safe", pure=True) specs = yaml_obj.load(stream) - if not ('datasets' in specs or 'groups' in specs): + if not ("datasets" in specs or "groups" in specs): raise ValueError("no 'groups' or 'datasets' found in %s" % spec_path) return specs @@ -223,19 +269,32 @@ def __get_spec_path(self, spec_path): class NamespaceCatalog: - - @docval({'name': 'group_spec_cls', 'type': type, - 'doc': 'the class to use for group specifications', 'default': GroupSpec}, - {'name': 'dataset_spec_cls', 'type': type, - 'doc': 'the class to use for dataset specifications', 'default': DatasetSpec}, - {'name': 'spec_namespace_cls', 'type': type, - 'doc': 'the class to use for specification namespaces', 'default': SpecNamespace}) + @docval( + { + "name": "group_spec_cls", + "type": type, + "doc": "the class to use for group specifications", + "default": GroupSpec, + }, + { + "name": "dataset_spec_cls", + "type": type, + "doc": "the class to use for dataset specifications", + "default": DatasetSpec, + }, + { + "name": "spec_namespace_cls", + "type": type, + "doc": "the class to use for specification namespaces", + "default": SpecNamespace, + }, + ) def __init__(self, **kwargs): """Create a catalog for storing multiple Namespaces""" self.__namespaces = OrderedDict() - self.__dataset_spec_cls = getargs('dataset_spec_cls', kwargs) - self.__group_spec_cls = getargs('group_spec_cls', kwargs) - self.__spec_namespace_cls = getargs('spec_namespace_cls', kwargs) + self.__dataset_spec_cls = getargs("dataset_spec_cls", kwargs) + self.__group_spec_cls = getargs("group_spec_cls", kwargs) + self.__spec_namespace_cls = getargs("spec_namespace_cls", kwargs) # keep track of all spec objects ever loaded, so we don't have # multiple object instances of a spec self.__loaded_specs = dict() @@ -245,9 +304,11 @@ def __init__(self, **kwargs): self._loaded_specs = self.__loaded_specs def __copy__(self): - ret = NamespaceCatalog(self.__group_spec_cls, - self.__dataset_spec_cls, - self.__spec_namespace_cls) + ret = NamespaceCatalog( + self.__group_spec_cls, + self.__dataset_spec_cls, + self.__spec_namespace_cls, + ) ret.__namespaces = copy(self.__namespaces) ret.__loaded_specs = copy(self.__loaded_specs) ret.__included_specs = copy(self.__included_specs) @@ -259,7 +320,7 @@ def merge(self, ns_catalog): self.add_namespace(name, namespace) @property - @docval(returns='a tuple of the available namespaces', rtype=tuple) + @docval(returns="a tuple of the available namespaces", rtype=tuple) def namespaces(self): """The namespaces in 
this NamespaceCatalog""" return tuple(self.__namespaces.keys()) @@ -279,11 +340,13 @@ def spec_namespace_cls(self): """The SpecNamespace class used in this NamespaceCatalog""" return self.__spec_namespace_cls - @docval({'name': 'name', 'type': str, 'doc': 'the name of this namespace'}, - {'name': 'namespace', 'type': SpecNamespace, 'doc': 'the SpecNamespace object'}) + @docval( + {"name": "name", "type": str, "doc": "the name of this namespace"}, + {"name": "namespace", "type": SpecNamespace, "doc": "the SpecNamespace object"}, + ) def add_namespace(self, **kwargs): """Add a namespace to this catalog""" - name, namespace = getargs('name', 'namespace', kwargs) + name, namespace = getargs("name", "namespace", kwargs) if name in self.__namespaces: raise KeyError("namespace '%s' already exists" % name) self.__namespaces[name] = namespace @@ -293,76 +356,92 @@ def add_namespace(self, **kwargs): # use dict with None values as ordered set because order of specs does matter self.__loaded_specs.setdefault(source, dict()).update({dt: None}) - @docval({'name': 'name', 'type': str, 'doc': 'the name of this namespace'}, - returns="the SpecNamespace with the given name", rtype=SpecNamespace) + @docval( + {"name": "name", "type": str, "doc": "the name of this namespace"}, + returns="the SpecNamespace with the given name", + rtype=SpecNamespace, + ) def get_namespace(self, **kwargs): """Get the a SpecNamespace""" - name = getargs('name', kwargs) + name = getargs("name", kwargs) ret = self.__namespaces.get(name) if ret is None: raise KeyError("'%s' not a namespace" % name) return ret - @docval({'name': 'namespace', 'type': str, 'doc': 'the name of the namespace'}, - {'name': 'data_type', 'type': (str, type), 'doc': 'the data_type to get the spec for'}, - returns="the specification for writing the given object type to HDF5 ", rtype='Spec') + @docval( + {"name": "namespace", "type": str, "doc": "the name of the namespace"}, + {"name": "data_type", "type": (str, type), "doc": "the data_type to get the spec for"}, + returns="the specification for writing the given object type to HDF5 ", + rtype="Spec", + ) def get_spec(self, **kwargs): - ''' + """ Get the Spec object for the given type from the given Namespace - ''' - namespace, data_type = getargs('namespace', 'data_type', kwargs) + """ + namespace, data_type = getargs("namespace", "data_type", kwargs) if namespace not in self.__namespaces: raise KeyError("'%s' not a namespace" % namespace) return self.__namespaces[namespace].get_spec(data_type) - @docval({'name': 'namespace', 'type': str, 'doc': 'the name of the namespace'}, - {'name': 'data_type', 'type': (str, type), 'doc': 'the data_type to get the spec for'}, - returns="a tuple with the type hierarchy", rtype=tuple) + @docval( + {"name": "namespace", "type": str, "doc": "the name of the namespace"}, + {"name": "data_type", "type": (str, type), "doc": "the data_type to get the spec for"}, + returns="a tuple with the type hierarchy", + rtype=tuple, + ) def get_hierarchy(self, **kwargs): - ''' + """ Get the type hierarchy for a given data_type in a given namespace - ''' - namespace, data_type = getargs('namespace', 'data_type', kwargs) + """ + namespace, data_type = getargs("namespace", "data_type", kwargs) spec_ns = self.__namespaces.get(namespace) if spec_ns is None: raise KeyError("'%s' not a namespace" % namespace) return spec_ns.get_hierarchy(data_type) - @docval({'name': 'namespace', 'type': str, 'doc': 'the name of the namespace containing the data_type'}, - {'name': 'data_type', 'type': str, 'doc': 
'the data_type to check'}, - {'name': 'parent_data_type', 'type': str, 'doc': 'the potential parent data_type'}, - returns="True if *data_type* is a sub `data_type` of *parent_data_type*, False otherwise", rtype=bool) + @docval( + {"name": "namespace", "type": str, "doc": "the name of the namespace containing the data_type"}, + {"name": "data_type", "type": str, "doc": "the data_type to check"}, + {"name": "parent_data_type", "type": str, "doc": "the potential parent data_type"}, + returns="True if *data_type* is a sub `data_type` of *parent_data_type*, False otherwise", + rtype=bool, + ) def is_sub_data_type(self, **kwargs): - ''' + """ Return whether or not *data_type* is a sub `data_type` of *parent_data_type* - ''' - ns, dt, parent_dt = getargs('namespace', 'data_type', 'parent_data_type', kwargs) + """ + ns, dt, parent_dt = getargs("namespace", "data_type", "parent_data_type", kwargs) hier = self.get_hierarchy(ns, dt) return parent_dt in hier @docval(rtype=tuple) def get_sources(self, **kwargs): - ''' + """ Get all the source specification files that were loaded in this catalog - ''' + """ return tuple(self.__loaded_specs.keys()) - @docval({'name': 'namespace', 'type': str, 'doc': 'the name of the namespace'}, - rtype=tuple) + @docval( + {"name": "namespace", "type": str, "doc": "the name of the namespace"}, + rtype=tuple, + ) def get_namespace_sources(self, **kwargs): - ''' + """ Get all the source specifications that were loaded for a given namespace - ''' - namespace = getargs('namespace', kwargs) + """ + namespace = getargs("namespace", kwargs) return tuple(self.__included_sources[namespace]) - @docval({'name': 'source', 'type': str, 'doc': 'the name of the source'}, - rtype=tuple) + @docval( + {"name": "source", "type": str, "doc": "the name of the source"}, + rtype=tuple, + ) def get_types(self, **kwargs): - ''' + """ Get the types that were loaded from a given source - ''' - source = getargs('source', kwargs) + """ + source = getargs("source", kwargs) ret = self.__loaded_specs.get(source) if ret is not None: ret = tuple(ret) @@ -378,7 +457,7 @@ def __load_spec_file(self, reader, spec_source, catalog, types_to_load=None, res def __reg_spec(spec_cls, spec_dict): dt_def = spec_dict.get(spec_cls.def_key()) if dt_def is None: - msg = 'No data type def key found in spec %s' % spec_source + msg = "No data type def key found in spec %s" % spec_source raise ValueError(msg) if types_to_load and dt_def not in types_to_load: return @@ -390,12 +469,12 @@ def __reg_spec(spec_cls, spec_dict): if ret is None: ret = dict() # this is used as an ordered set -- values are all none d = reader.read_spec(spec_source) - specs = d.get('datasets', list()) + specs = d.get("datasets", list()) for spec_dict in specs: self.__convert_spec_cls_keys(GroupSpec, self.__group_spec_cls, spec_dict) temp_dict = {k: None for k in __reg_spec(self.__dataset_spec_cls, spec_dict)} ret.update(temp_dict) - specs = d.get('groups', list()) + specs = d.get("groups", list()) for spec_dict in specs: self.__convert_spec_cls_keys(GroupSpec, self.__group_spec_cls, spec_dict) temp_dict = {k: None for k in __reg_spec(self.__group_spec_cls, spec_dict)} @@ -413,8 +492,7 @@ def __convert_spec_cls_keys(self, parent_cls, spec_cls, spec_dict): spec_dict[spec_cls.inc_key()] = spec_dict.pop(parent_cls.inc_key()) def __resolve_includes(self, spec_cls, spec_dict, catalog): - """Replace data type inc strings with the spec definition so the new spec is built with included fields. 
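The public query methods touched above (get_sources, get_namespace_sources, get_types, is_sub_data_type) operate on a catalog that has already loaded one or more namespaces; a rough sketch, where the file names, the 'core' namespace, and the type names are hypothetical:

    from hdmf.spec import NamespaceCatalog

    ns_catalog = NamespaceCatalog()
    ns_catalog.load_namespaces("core.namespace.yaml")   # hypothetical namespace file on disk

    ns_catalog.get_sources()                  # all spec source files loaded into the catalog
    ns_catalog.get_namespace_sources("core")  # sources loaded for the 'core' namespace
    ns_catalog.get_types("core.types.yaml")   # types loaded from one source file
    ns_catalog.is_sub_data_type("core", "AContainer", "BaseContainer")  # True if AContainer extends BaseContainer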
- """ + """Replace data type inc strings with the spec definition so the new spec is built with included fields.""" dt_def = spec_dict.get(spec_cls.def_key()) dt_inc = spec_dict.get(spec_cls.inc_key()) if dt_inc is not None and dt_def is not None: @@ -425,39 +503,48 @@ def __resolve_includes(self, spec_cls, spec_dict, catalog): # replace the inc key value from string to the inc spec so that the spec can be updated with all of the # attributes, datasets, groups, and links of the inc spec when spec_cls.build_spec(spec_dict) is called spec_dict[spec_cls.inc_key()] = parent_spec - for subspec_dict in spec_dict.get('groups', list()): + for subspec_dict in spec_dict.get("groups", list()): self.__resolve_includes(self.__group_spec_cls, subspec_dict, catalog) - for subspec_dict in spec_dict.get('datasets', list()): + for subspec_dict in spec_dict.get("datasets", list()): self.__resolve_includes(self.__dataset_spec_cls, subspec_dict, catalog) def __load_namespace(self, namespace, reader, resolve=True): - ns_name = namespace['name'] + ns_name = namespace["name"] if ns_name in self.__namespaces: # pragma: no cover raise KeyError("namespace '%s' already exists" % ns_name) catalog = SpecCatalog() included_types = dict() - for s in namespace['schema']: + for s in namespace["schema"]: # types_key may be different in each spec namespace, so check both the __spec_namespace_cls types key # and the parent SpecNamespace types key. NOTE: this does not handle more than one custom types key - types_to_load = s.get(self.__spec_namespace_cls.types_key(), s.get(SpecNamespace.types_key())) + types_to_load = s.get( + self.__spec_namespace_cls.types_key(), + s.get(SpecNamespace.types_key()), + ) if types_to_load is not None: # schema specifies specific types from 'source' or 'namespace' types_to_load = set(types_to_load) - if 'source' in s: + if "source" in s: # read specs from file - self.__load_spec_file(reader, s['source'], catalog, types_to_load=types_to_load, resolve=resolve) - self.__included_sources.setdefault(ns_name, list()).append(s['source']) - elif 'namespace' in s: + self.__load_spec_file( + reader, + s["source"], + catalog, + types_to_load=types_to_load, + resolve=resolve, + ) + self.__included_sources.setdefault(ns_name, list()).append(s["source"]) + elif "namespace" in s: # load specs from namespace try: - inc_ns = self.get_namespace(s['namespace']) + inc_ns = self.get_namespace(s["namespace"]) except KeyError as e: - raise ValueError("Could not load namespace '%s'" % s['namespace']) from e + raise ValueError("Could not load namespace '%s'" % s["namespace"]) from e if types_to_load is None: types_to_load = inc_ns.get_registered_types() # load all types in namespace registered_types = set() for ndt in types_to_load: self.__register_type(ndt, inc_ns, catalog, registered_types) - included_types[s['namespace']] = tuple(sorted(registered_types)) + included_types[s["namespace"]] = tuple(sorted(registered_types)) else: raise ValueError("Spec '%s' schema must have either 'source' or 'namespace' key" % ns_name) # construct namespace @@ -477,8 +564,8 @@ def __register_type(self, ndt, inc_ns, catalog, registered_types): catalog.register_spec(built_spec, spec_file) def __register_dependent_types(self, spec, inc_ns, catalog, registered_types): - """Ensure that classes for all types used by this type are registered - """ + """Ensure that classes for all types used by this type are registered""" + # TODO test cross-namespace registration... 
def __register_dependent_types_helper(spec, inc_ns, catalog, registered_types): if isinstance(spec, (GroupSpec, DatasetSpec)): @@ -490,26 +577,39 @@ def __register_dependent_types_helper(spec, inc_ns, catalog, registered_types): else: # spec is a LinkSpec self.__register_type(spec.target_type, inc_ns, catalog, registered_types) if isinstance(spec, GroupSpec): - for child_spec in (spec.groups + spec.datasets + spec.links): + for child_spec in spec.groups + spec.datasets + spec.links: __register_dependent_types_helper(child_spec, inc_ns, catalog, registered_types) if spec.data_type_inc is not None: self.__register_type(spec.data_type_inc, inc_ns, catalog, registered_types) if isinstance(spec, GroupSpec): - for child_spec in (spec.groups + spec.datasets + spec.links): + for child_spec in spec.groups + spec.datasets + spec.links: __register_dependent_types_helper(child_spec, inc_ns, catalog, registered_types) - @docval({'name': 'namespace_path', 'type': str, 'doc': 'the path to the file containing the namespaces(s) to load'}, - {'name': 'resolve', - 'type': bool, - 'doc': 'whether or not to include objects from included/parent spec objects', 'default': True}, - {'name': 'reader', - 'type': SpecReader, - 'doc': 'the class to user for reading specifications', 'default': None}, - returns='a dictionary describing the dependencies of loaded namespaces', rtype=dict) + @docval( + { + "name": "namespace_path", + "type": str, + "doc": "the path to the file containing the namespaces(s) to load", + }, + { + "name": "resolve", + "type": bool, + "doc": "whether or not to include objects from included/parent spec objects", + "default": True, + }, + { + "name": "reader", + "type": SpecReader, + "doc": "the class to user for reading specifications", + "default": None, + }, + returns="a dictionary describing the dependencies of loaded namespaces", + rtype=dict, + ) def load_namespaces(self, **kwargs): """Load the namespaces in the given file""" - namespace_path, resolve, reader = getargs('namespace_path', 'resolve', 'reader', kwargs) + namespace_path, resolve, reader = getargs("namespace_path", "resolve", "reader", kwargs) if reader is None: # load namespace definition from file if not os.path.exists(namespace_path): @@ -525,15 +625,17 @@ def load_namespaces(self, **kwargs): namespaces = reader.read_namespace(namespace_path) to_load = list() for ns in namespaces: - if ns['name'] in self.__namespaces: - if ns['version'] != self.__namespaces.get(ns['name'])['version']: + if ns["name"] in self.__namespaces: + if ns["version"] != self.__namespaces.get(ns["name"])["version"]: # warn if the cached namespace differs from the already loaded namespace - warn("Ignoring cached namespace '%s' version %s because version %s is already loaded." - % (ns['name'], ns['version'], self.__namespaces.get(ns['name'])['version'])) + warn( + "Ignoring cached namespace '%s' version %s because version %s is already loaded." 
+ % (ns["name"], ns["version"], self.__namespaces.get(ns["name"])["version"]) + ) else: to_load.append(ns) # now load specs into namespace for ns in to_load: - ret[ns['name']] = self.__load_namespace(ns, reader, resolve=resolve) + ret[ns["name"]] = self.__load_namespace(ns, reader, resolve=resolve) self.__included_specs[ns_path_key] = ret return ret diff --git a/src/hdmf/spec/spec.py b/src/hdmf/spec/spec.py index 183245853..8ff74774c 100644 --- a/src/hdmf/spec/spec.py +++ b/src/hdmf/spec/spec.py @@ -4,17 +4,17 @@ from copy import deepcopy from warnings import warn -from ..utils import docval, getargs, popargs, get_docval +from ..utils import docval, get_docval, getargs, popargs NAME_WILDCARD = None # this is no longer used, but kept for backward compatibility -ZERO_OR_ONE = '?' -ZERO_OR_MANY = '*' -ONE_OR_MANY = '+' +ZERO_OR_ONE = "?" +ZERO_OR_MANY = "*" +ONE_OR_MANY = "+" DEF_QUANTITY = 1 FLAGS = { - 'zero_or_one': ZERO_OR_ONE, - 'zero_or_many': ZERO_OR_MANY, - 'one_or_many': ONE_OR_MANY + "zero_or_one": ZERO_OR_ONE, + "zero_or_many": ZERO_OR_MANY, + "one_or_many": ONE_OR_MANY, } @@ -25,35 +25,36 @@ class DtypeHelper: # hdmf.validate.validator.__allowable, and backend dtype maps # see https://hdmf-schema-language.readthedocs.io/en/latest/description.html#dtype primary_dtype_synonyms = { - 'float': ["float", "float32"], - 'double': ["double", "float64"], - 'short': ["int16", "short"], - 'int': ["int32", "int"], - 'long': ["int64", "long"], - 'utf': ["text", "utf", "utf8", "utf-8"], - 'ascii': ["ascii", "bytes"], - 'bool': ["bool"], - 'int8': ["int8"], - 'uint8': ["uint8"], - 'uint16': ["uint16"], - 'uint32': ["uint32", "uint"], - 'uint64': ["uint64"], - 'object': ['object'], - 'region': ['region'], - 'numeric': ['numeric'], - 'isodatetime': ["isodatetime", "datetime"] + "float": ["float", "float32"], + "double": ["double", "float64"], + "short": ["int16", "short"], + "int": ["int32", "int"], + "long": ["int64", "long"], + "utf": ["text", "utf", "utf8", "utf-8"], + "ascii": ["ascii", "bytes"], + "bool": ["bool"], + "int8": ["int8"], + "uint8": ["uint8"], + "uint16": ["uint16"], + "uint32": ["uint32", "uint"], + "uint64": ["uint64"], + "object": ["object"], + "region": ["region"], + "numeric": ["numeric"], + "isodatetime": ["isodatetime", "datetime"], } # List of recommended primary dtype strings. These are the keys of primary_dtype_string_synonyms recommended_primary_dtypes = list(primary_dtype_synonyms.keys()) # List of valid primary data type strings - valid_primary_dtypes = set(list(primary_dtype_synonyms.keys()) + - [vi for v in primary_dtype_synonyms.values() for vi in v]) + valid_primary_dtypes = set( + list(primary_dtype_synonyms.keys()) + [vi for v in primary_dtype_synonyms.values() for vi in v] + ) @staticmethod def simplify_cpd_type(cpd_type): - ''' + """ Transform a list of DtypeSpecs into a list of strings. Use for simple representation of compound type and validation. @@ -61,7 +62,7 @@ def simplify_cpd_type(cpd_type): :param cpd_type: The list of DtypeSpecs to simplify :type cpd_type: list - ''' + """ ret = list() for exp in cpd_type: exp_key = exp.dtype @@ -74,74 +75,77 @@ def simplify_cpd_type(cpd_type): def check_dtype(dtype): """Check that the dtype string is a reference or a valid primary dtype.""" if not isinstance(dtype, RefSpec) and dtype not in DtypeHelper.valid_primary_dtypes: - raise ValueError("dtype '%s' is not a valid primary data type. 
Allowed dtypes: %s" - % (dtype, str(DtypeHelper.valid_primary_dtypes))) + raise ValueError( + "dtype '%s' is not a valid primary data type. Allowed dtypes: %s" + % (dtype, str(DtypeHelper.valid_primary_dtypes)) + ) return dtype class ConstructableDict(dict, metaclass=ABCMeta): @classmethod def build_const_args(cls, spec_dict): - ''' Build constructor arguments for this ConstructableDict class from a dictionary ''' + """Build constructor arguments for this ConstructableDict class from a dictionary""" # main use cases are when spec_dict is a ConstructableDict or a spec dict read from a file return deepcopy(spec_dict) @classmethod def build_spec(cls, spec_dict): - ''' Build a Spec object from the given Spec dict ''' + """Build a Spec object from the given Spec dict""" # main use cases are when spec_dict is a ConstructableDict or a spec dict read from a file vargs = cls.build_const_args(spec_dict) kwargs = dict() # iterate through the Spec docval and construct kwargs based on matching values in spec_dict for x in get_docval(cls.__init__): - if x['name'] in vargs: - kwargs[x['name']] = vargs.get(x['name']) + if x["name"] in vargs: + kwargs[x["name"]] = vargs.get(x["name"]) return cls(**kwargs) class Spec(ConstructableDict): - ''' A base specification class - ''' - - @docval({'name': 'doc', 'type': str, 'doc': 'a description about what this specification represents'}, - {'name': 'name', 'type': str, 'doc': 'The name of this attribute', 'default': None}, - {'name': 'required', 'type': bool, 'doc': 'whether or not this attribute is required', 'default': True}, - {'name': 'parent', 'type': 'Spec', 'doc': 'the parent of this spec', 'default': None}) + """A base specification class""" + + @docval( + {"name": "doc", "type": str, "doc": "a description about what this specification represents"}, + {"name": "name", "type": str, "doc": "The name of this attribute", "default": None}, + {"name": "required", "type": bool, "doc": "whether or not this attribute is required", "default": True}, + {"name": "parent", "type": "Spec", "doc": "the parent of this spec", "default": None}, + ) def __init__(self, **kwargs): - name, doc, required, parent = getargs('name', 'doc', 'required', 'parent', kwargs) + name, doc, required, parent = getargs("name", "doc", "required", "parent", kwargs) super().__init__() - self['doc'] = doc + self["doc"] = doc if name is not None: - self['name'] = name + self["name"] = name if not required: - self['required'] = required + self["required"] = required self._parent = parent @property def doc(self): - ''' Documentation on what this Spec is specifying ''' - return self.get('doc', None) + """Documentation on what this Spec is specifying""" + return self.get("doc", None) @property def name(self): - ''' The name of the object being specified ''' - return self.get('name', None) + """The name of the object being specified""" + return self.get("name", None) @property def parent(self): - ''' The parent specification of this specification ''' + """The parent specification of this specification""" return self._parent @parent.setter def parent(self, spec): - ''' Set the parent of this specification ''' + """Set the parent of this specification""" if self._parent is not None: - raise AttributeError('Cannot re-assign parent.') + raise AttributeError("Cannot re-assign parent.") self._parent = spec @classmethod def build_const_args(cls, spec_dict): - ''' Build constructor arguments for this Spec class from a dictionary ''' + """Build constructor arguments for this Spec class from a dictionary""" ret = 
super().build_const_args(spec_dict) return ret @@ -167,169 +171,250 @@ def path(self): # return id(self) == id(other) -_target_type_key = 'target_type' +_target_type_key = "target_type" _ref_args = [ - {'name': _target_type_key, 'type': str, 'doc': 'the target type GroupSpec or DatasetSpec'}, - {'name': 'reftype', 'type': str, 'doc': 'the type of references this is i.e. region or object'}, + {"name": _target_type_key, "type": str, "doc": "the target type GroupSpec or DatasetSpec"}, + {"name": "reftype", "type": str, "doc": "the type of references this is i.e. region or object"}, ] class RefSpec(ConstructableDict): - __allowable_types = ('object', 'region') + __allowable_types = ("object", "region") @docval(*_ref_args) def __init__(self, **kwargs): - target_type, reftype = getargs(_target_type_key, 'reftype', kwargs) + target_type, reftype = getargs(_target_type_key, "reftype", kwargs) self[_target_type_key] = target_type if reftype not in self.__allowable_types: msg = "reftype must be one of the following: %s" % ", ".join(self.__allowable_types) raise ValueError(msg) - self['reftype'] = reftype + self["reftype"] = reftype @property def target_type(self): - '''The data_type of the target of the reference''' + """The data_type of the target of the reference""" return self[_target_type_key] @property def reftype(self): - '''The type of reference''' - return self['reftype'] + """The type of reference""" + return self["reftype"] - @docval(rtype=bool, returns='True if this RefSpec specifies a region reference, False otherwise') + @docval( + rtype=bool, + returns="True if this RefSpec specifies a region reference, False otherwise", + ) def is_region(self): - return self['reftype'] == 'region' + return self["reftype"] == "region" _attr_args = [ - {'name': 'name', 'type': str, 'doc': 'The name of this attribute'}, - {'name': 'doc', 'type': str, 'doc': 'a description about what this specification represents'}, - {'name': 'dtype', 'type': (str, RefSpec), 'doc': 'The data type of this attribute'}, - {'name': 'shape', 'type': (list, tuple), 'doc': 'the shape of this dataset', 'default': None}, - {'name': 'dims', 'type': (list, tuple), 'doc': 'the dimensions of this dataset', 'default': None}, - {'name': 'required', 'type': bool, - 'doc': 'whether or not this attribute is required. ignored when "value" is specified', 'default': True}, - {'name': 'parent', 'type': 'BaseStorageSpec', 'doc': 'the parent of this spec', 'default': None}, - {'name': 'value', 'type': None, 'doc': 'a constant value for this attribute', 'default': None}, - {'name': 'default_value', 'type': None, 'doc': 'a default value for this attribute', 'default': None} + {"name": "name", "type": str, "doc": "The name of this attribute"}, + { + "name": "doc", + "type": str, + "doc": "a description about what this specification represents", + }, + { + "name": "dtype", + "type": (str, RefSpec), + "doc": "The data type of this attribute", + }, + { + "name": "shape", + "type": (list, tuple), + "doc": "the shape of this dataset", + "default": None, + }, + { + "name": "dims", + "type": (list, tuple), + "doc": "the dimensions of this dataset", + "default": None, + }, + { + "name": "required", + "type": bool, + "doc": 'whether or not this attribute is required. 
ignored when "value" is specified', + "default": True, + }, + { + "name": "parent", + "type": "BaseStorageSpec", + "doc": "the parent of this spec", + "default": None, + }, + { + "name": "value", + "type": None, + "doc": "a constant value for this attribute", + "default": None, + }, + { + "name": "default_value", + "type": None, + "doc": "a default value for this attribute", + "default": None, + }, ] class AttributeSpec(Spec): - ''' Specification for attributes - ''' + """Specification for attributes""" @docval(*_attr_args) def __init__(self, **kwargs): - name, dtype, doc, dims, shape, required, parent, value, default_value = getargs( - 'name', 'dtype', 'doc', 'dims', 'shape', 'required', 'parent', 'value', 'default_value', kwargs) + name, dtype, doc, dims, shape = getargs("name", "dtype", "doc", "dims", "shape", kwargs) + required, parent, value, default_value = getargs("required", "parent", "value", "default_value", kwargs) super().__init__(doc, name=name, required=required, parent=parent) - self['dtype'] = DtypeHelper.check_dtype(dtype) + self["dtype"] = DtypeHelper.check_dtype(dtype) if value is not None: - self.pop('required', None) - self['value'] = value + self.pop("required", None) + self["value"] = value if default_value is not None: if value is not None: raise ValueError("cannot specify 'value' and 'default_value'") - self['default_value'] = default_value - self['required'] = False + self["default_value"] = default_value + self["required"] = False if shape is not None: - self['shape'] = shape + self["shape"] = shape if dims is not None: - self['dims'] = dims - if 'shape' not in self: - self['shape'] = tuple([None] * len(dims)) + self["dims"] = dims + if "shape" not in self: + self["shape"] = tuple([None] * len(dims)) if self.shape is not None and self.dims is not None: - if len(self['dims']) != len(self['shape']): + if len(self["dims"]) != len(self["shape"]): raise ValueError("'dims' and 'shape' must be the same length") @property def dtype(self): - ''' The data type of the attribute ''' - return self.get('dtype', None) + """The data type of the attribute""" + return self.get("dtype", None) @property def value(self): - ''' The constant value of the attribute. "None" if this attribute is not constant ''' - return self.get('value', None) + """The constant value of the attribute. "None" if this attribute is not constant""" + return self.get("value", None) @property def default_value(self): - ''' The default value of the attribute. "None" if this attribute has no default value ''' - return self.get('default_value', None) + """The default value of the attribute. "None" if this attribute has no default value""" + return self.get("default_value", None) @property def required(self): - ''' True if this attribute is required, False otherwise. 
''' - return self.get('required', True) + """True if this attribute is required, False otherwise.""" + return self.get("required", True) @property def dims(self): - ''' The dimensions of this attribute's value ''' - return self.get('dims', None) + """The dimensions of this attribute's value""" + return self.get("dims", None) @property def shape(self): - ''' The shape of this attribute's value ''' - return self.get('shape', None) + """The shape of this attribute's value""" + return self.get("shape", None) @classmethod def build_const_args(cls, spec_dict): - ''' Build constructor arguments for this Spec class from a dictionary ''' + """Build constructor arguments for this Spec class from a dictionary""" ret = super().build_const_args(spec_dict) - if isinstance(ret['dtype'], dict): - ret['dtype'] = RefSpec.build_spec(ret['dtype']) + if isinstance(ret["dtype"], dict): + ret["dtype"] = RefSpec.build_spec(ret["dtype"]) return ret _attrbl_args = [ - {'name': 'doc', 'type': str, 'doc': 'a description about what this specification represents'}, - {'name': 'name', 'type': str, - 'doc': 'the name of this base storage container, allowed only if quantity is not \'%s\' or \'%s\'' - % (ONE_OR_MANY, ZERO_OR_MANY), 'default': None}, - {'name': 'default_name', 'type': str, - 'doc': 'The default name of this base storage container, used only if name is None', 'default': None}, - {'name': 'attributes', 'type': list, 'doc': 'the attributes on this group', 'default': list()}, - {'name': 'linkable', 'type': bool, 'doc': 'whether or not this group can be linked', 'default': True}, - {'name': 'quantity', 'type': (str, int), 'doc': 'the required number of allowed instance', 'default': 1}, - {'name': 'data_type_def', 'type': str, 'doc': 'the data type this specification represents', 'default': None}, - {'name': 'data_type_inc', 'type': (str, 'BaseStorageSpec'), - 'doc': 'the data type this specification extends', 'default': None}, + { + "name": "doc", + "type": str, + "doc": "a description about what this specification represents", + }, + { + "name": "name", + "type": str, + "doc": "the name of this base storage container, allowed only if quantity is not '%s' or '%s'" % ( + ONE_OR_MANY, + ZERO_OR_MANY, + ), + "default": None, + }, + { + "name": "default_name", + "type": str, + "doc": "The default name of this base storage container, used only if name is None", + "default": None, + }, + { + "name": "attributes", + "type": list, + "doc": "the attributes on this group", + "default": list(), + }, + { + "name": "linkable", + "type": bool, + "doc": "whether or not this group can be linked", + "default": True, + }, + { + "name": "quantity", + "type": (str, int), + "doc": "the required number of allowed instance", + "default": 1, + }, + { + "name": "data_type_def", + "type": str, + "doc": "the data type this specification represents", + "default": None, + }, + { + "name": "data_type_inc", + "type": (str, "BaseStorageSpec"), + "doc": "the data type this specification extends", + "default": None, + }, ] class BaseStorageSpec(Spec): - ''' A specification for any object that can hold attributes. 
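The name/quantity checks in BaseStorageSpec.__init__ interact with the '?', '*', and '+' quantity flags defined near the top of spec.py; a rough sketch using GroupSpec (the names and types are hypothetical):

    from hdmf.spec import GroupSpec

    # a named subgroup that may appear at most once ('?' == ZERO_OR_ONE)
    optional = GroupSpec(doc="an optional, named subgroup", name="metadata", quantity="?")

    # an unnamed, repeatable group identified only by its included data type ('*' == ZERO_OR_MANY)
    repeated = GroupSpec(doc="zero or more MyType groups", data_type_inc="MyType", quantity="*")

    # a fixed name combined with a repeatable quantity is rejected:
    # GroupSpec(doc="invalid", name="fixed", quantity="*")  # raises ValueError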
''' + """A specification for any object that can hold attributes.""" - __inc_key = 'data_type_inc' - __def_key = 'data_type_def' - __type_key = 'data_type' - __id_key = 'object_id' + __inc_key = "data_type_inc" + __def_key = "data_type_def" + __type_key = "data_type" + __id_key = "object_id" @docval(*_attrbl_args) def __init__(self, **kwargs): - name, doc, quantity, attributes, linkable, data_type_def, data_type_inc = \ - getargs('name', 'doc', 'quantity', 'attributes', 'linkable', 'data_type_def', 'data_type_inc', kwargs) + name, doc, quantity, attributes = getargs("name", "doc", "quantity", "attributes", kwargs) + linkable, data_type_def, data_type_inc = getargs("linkable", "data_type_def", "data_type_inc", kwargs) if name is None and data_type_def is None and data_type_inc is None: - raise ValueError("Cannot create Group or Dataset spec with no name " - "without specifying '%s' and/or '%s'." % (self.def_key(), self.inc_key())) + raise ValueError( + "Cannot create Group or Dataset spec with no name without specifying '%s' and/or '%s'." + % (self.def_key(), self.inc_key()) + ) super().__init__(doc, name=name) - default_name = getargs('default_name', kwargs) + default_name = getargs("default_name", kwargs) if default_name: if name is not None: warn("found 'default_name' with 'name' - ignoring 'default_name'") else: - self['default_name'] = default_name + self["default_name"] = default_name self.__attributes = dict() if quantity in (ONE_OR_MANY, ZERO_OR_MANY): if name is not None: - raise ValueError("Cannot give specific name to something that can " - "exist multiple times: name='%s', quantity='%s'" % (name, quantity)) + raise ValueError( + "Cannot give specific name to something that can exist multiple times: name='%s', quantity='%s'" + % (name, quantity) + ) if quantity != DEF_QUANTITY: - self['quantity'] = quantity + self["quantity"] = quantity if not linkable: - self['linkable'] = False + self["linkable"] = False resolve = False if data_type_inc is not None: if isinstance(data_type_inc, BaseStorageSpec): @@ -337,7 +422,7 @@ def __init__(self, **kwargs): else: self[self.inc_key()] = data_type_inc if data_type_def is not None: - self.pop('required', None) + self.pop("required", None) self[self.def_key()] = data_type_def # resolve inherited and overridden fields only if data_type_inc is a spec # NOTE: this does not happen when loading specs from a file @@ -360,8 +445,8 @@ def __init__(self, **kwargs): @property def default_name(self): - '''The default name for this spec''' - return self.get('default_name', None) + """The default name for this spec""" + return self.get("default_name", None) @property def resolved(self): @@ -369,13 +454,13 @@ def resolved(self): @property def required(self): - ''' Whether or not the this spec represents a required field ''' + """Whether or not the this spec represents a required field""" return self.quantity not in (ZERO_OR_ONE, ZERO_OR_MANY) - @docval({'name': 'inc_spec', 'type': 'BaseStorageSpec', 'doc': 'the data type this specification represents'}) + @docval({"name": "inc_spec", "type": "BaseStorageSpec", "doc": "the data type this specification represents"}) def resolve_spec(self, **kwargs): """Add attributes from the inc_spec to this spec and track which attributes are new and overridden.""" - inc_spec = getargs('inc_spec', kwargs) + inc_spec = getargs("inc_spec", kwargs) for attribute in inc_spec.attributes: self.__new_attributes.discard(attribute.name) if attribute.name in self.__attributes: @@ -384,54 +469,54 @@ def resolve_spec(self, **kwargs): 
self.set_attribute(attribute) self.__resolved = True - @docval({'name': 'spec', 'type': (Spec, str), 'doc': 'the specification to check'}) + @docval({"name": "spec", "type": (Spec, str), "doc": "the specification to check"}) def is_inherited_spec(self, **kwargs): - ''' + """ Return True if this spec was inherited from the parent type, False otherwise. Returns False if the spec is not found. - ''' - spec = getargs('spec', kwargs) + """ + spec = getargs("spec", kwargs) if isinstance(spec, Spec): spec = spec.name if spec in self.__attributes: return self.is_inherited_attribute(spec) return False - @docval({'name': 'spec', 'type': (Spec, str), 'doc': 'the specification to check'}) + @docval({"name": "spec", "type": (Spec, str), "doc": "the specification to check"}) def is_overridden_spec(self, **kwargs): - ''' + """ Return True if this spec overrides a specification from the parent type, False otherwise. Returns False if the spec is not found. - ''' - spec = getargs('spec', kwargs) + """ + spec = getargs("spec", kwargs) if isinstance(spec, Spec): spec = spec.name if spec in self.__attributes: return self.is_overridden_attribute(spec) return False - @docval({'name': 'name', 'type': str, 'doc': 'the name of the attribute to check'}) + @docval({"name": "name", "type": str, "doc": "the name of the attribute to check"}) def is_inherited_attribute(self, **kwargs): - ''' + """ Return True if the attribute was inherited from the parent type, False otherwise. Raises a ValueError if the spec is not found. - ''' - name = getargs('name', kwargs) + """ + name = getargs("name", kwargs) if name not in self.__attributes: raise ValueError("Attribute '%s' not found" % name) return name not in self.__new_attributes - @docval({'name': 'name', 'type': str, 'doc': 'the name of the attribute to check'}) + @docval({"name": "name", "type": str, "doc": "the name of the attribute to check"}) def is_overridden_attribute(self, **kwargs): - ''' + """ Return True if the given attribute overrides the specification from the parent, False otherwise. Raises a ValueError if the spec is not found. 
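The inheritance bookkeeping above only kicks in when a spec is built with an actual spec object as data_type_inc, which triggers resolve_spec; a rough sketch (the types and the 'help' attribute are hypothetical):

    from hdmf.spec import AttributeSpec, GroupSpec

    base = GroupSpec(
        doc="a base type",
        data_type_def="BaseContainer",
        attributes=[AttributeSpec(name="help", doc="help text", dtype="text")],
    )
    ext = GroupSpec(doc="an extension of BaseContainer", data_type_def="AContainer", data_type_inc=base)

    ext.is_inherited_attribute("help")   # True: 'help' was pulled in from BaseContainer
    ext.is_overridden_attribute("help")  # False: AContainer does not redefine it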
- ''' - name = getargs('name', kwargs) + """ + name = getargs("name", kwargs) if name not in self.__attributes: raise ValueError("Attribute '%s' not found" % name) return name in self.__overridden_attributes @@ -441,89 +526,99 @@ def is_many(self): @classmethod def get_data_type_spec(cls, data_type_def): # unused - return AttributeSpec(cls.type_key(), 'the data type of this object', 'text', value=data_type_def) + return AttributeSpec( + cls.type_key(), + "the data type of this object", + "text", + value=data_type_def, + ) @classmethod def get_namespace_spec(cls): # unused - return AttributeSpec('namespace', 'the namespace for the data type of this object', 'text', required=False) + return AttributeSpec( + "namespace", + "the namespace for the data type of this object", + "text", + required=False, + ) @property def attributes(self): - ''' Tuple of attribute specifications for this specification ''' - return tuple(self.get('attributes', tuple())) + """Tuple of attribute specifications for this specification""" + return tuple(self.get("attributes", tuple())) @property def linkable(self): - ''' True if object can be a link, False otherwise ''' - return self.get('linkable', True) + """True if object can be a link, False otherwise""" + return self.get("linkable", True) @classmethod def id_key(cls): - ''' Get the key used to store data ID on an instance + """Get the key used to store data ID on an instance Override this method to use a different name for 'object_id' - ''' + """ return cls.__id_key @classmethod def type_key(cls): - ''' Get the key used to store data type on an instance + """Get the key used to store data type on an instance Override this method to use a different name for 'data_type'. HDMF supports combining schema that uses 'data_type' and at most one different name for 'data_type'. - ''' + """ return cls.__type_key @classmethod def inc_key(cls): - ''' Get the key used to define a data_type include. + """Get the key used to define a data_type include. Override this method to use a different keyword for 'data_type_inc'. HDMF supports combining schema that uses 'data_type_inc' and at most one different name for 'data_type_inc'. - ''' + """ return cls.__inc_key @classmethod def def_key(cls): - ''' Get the key used to define a data_type definition. + """Get the key used to define a data_type definition. Override this method to use a different keyword for 'data_type_def' HDMF supports combining schema that uses 'data_type_def' and at most one different name for 'data_type_def'. 
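The id_key/type_key/inc_key/def_key classmethods exist so that a schema dialect can rename these keys by subclassing; a minimal sketch of such an override (the class and key names are hypothetical, in the spirit of NWB's 'neurodata_type_*' keys):

    from hdmf.spec import GroupSpec

    class MyGroupSpec(GroupSpec):
        """Hypothetical spec class that stores types under renamed keys."""

        @classmethod
        def type_key(cls):
            return "my_type"

        @classmethod
        def def_key(cls):
            return "my_type_def"

        @classmethod
        def inc_key(cls):
            return "my_type_inc"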
- ''' + """ return cls.__def_key @property def data_type_inc(self): - ''' The data type this specification inherits ''' + """The data type this specification inherits""" return self.get(self.inc_key()) @property def data_type_def(self): - ''' The data type this specification defines ''' + """The data type this specification defines""" return self.get(self.def_key(), None) @property def data_type(self): - ''' The data type of this specification ''' + """The data type of this specification""" return self.data_type_def or self.data_type_inc @property def quantity(self): - ''' The number of times the object being specified should be present ''' - return self.get('quantity', DEF_QUANTITY) + """The number of times the object being specified should be present""" + return self.get("quantity", DEF_QUANTITY) @docval(*_attr_args) def add_attribute(self, **kwargs): - ''' Add an attribute to this specification ''' + """Add an attribute to this specification""" spec = AttributeSpec(**kwargs) self.set_attribute(spec) return spec - @docval({'name': 'spec', 'type': AttributeSpec, 'doc': 'the specification for the attribute to add'}) + @docval({"name": "spec", "type": AttributeSpec, "doc": "the specification for the attribute to add"}) def set_attribute(self, **kwargs): - ''' Set an attribute on this specification ''' - spec = kwargs.get('spec') - attributes = self.setdefault('attributes', list()) + """Set an attribute on this specification""" + spec = kwargs.get("spec") + attributes = self.setdefault("attributes", list()) if spec.parent is not None: spec = AttributeSpec.build_spec(spec) # if attribute name already exists in self.__attributes, @@ -542,59 +637,59 @@ def set_attribute(self, **kwargs): if idx >= 0: attributes[idx] = spec else: # pragma: no cover - raise ValueError('%s in __attributes but not in spec record' % spec.name) + raise ValueError("%s in __attributes but not in spec record" % spec.name) else: attributes.append(spec) self.__attributes[spec.name] = spec spec.parent = self - @docval({'name': 'name', 'type': str, 'doc': 'the name of the attribute to the Spec for'}) + @docval({"name": "name", "type": str, "doc": "the name of the attribute to the Spec for"}) def get_attribute(self, **kwargs): - ''' Get an attribute on this specification ''' - name = getargs('name', kwargs) + """Get an attribute on this specification""" + name = getargs("name", kwargs) return self.__attributes.get(name) @classmethod def build_const_args(cls, spec_dict): - ''' Build constructor arguments for this Spec class from a dictionary ''' + """Build constructor arguments for this Spec class from a dictionary""" ret = super().build_const_args(spec_dict) - if 'attributes' in ret: - ret['attributes'] = [AttributeSpec.build_spec(sub_spec) for sub_spec in ret['attributes']] + if "attributes" in ret: + ret["attributes"] = [AttributeSpec.build_spec(sub_spec) for sub_spec in ret["attributes"]] return ret _dt_args = [ - {'name': 'name', 'type': str, 'doc': 'the name of this column'}, - {'name': 'doc', 'type': str, 'doc': 'a description about what this data type is'}, - {'name': 'dtype', 'type': (str, list, RefSpec), 'doc': 'the data type of this column'}, + {"name": "name", "type": str, "doc": "the name of this column"}, + {"name": "doc", "type": str, "doc": "a description about what this data type is"}, + {"name": "dtype", "type": (str, list, RefSpec), "doc": "the data type of this column"}, ] class DtypeSpec(ConstructableDict): - '''A class for specifying a component of a compound type''' + """A class for specifying a 
component of a compound type""" @docval(*_dt_args) def __init__(self, **kwargs): - doc, name, dtype = getargs('doc', 'name', 'dtype', kwargs) - self['doc'] = doc - self['name'] = name + doc, name, dtype = getargs("doc", "name", "dtype", kwargs) + self["doc"] = doc + self["name"] = name self.check_valid_dtype(dtype) - self['dtype'] = dtype + self["dtype"] = dtype @property def doc(self): - '''Documentation about this component''' - return self['doc'] + """Documentation about this component""" + return self["doc"] @property def name(self): - '''The name of this component''' - return self['name'] + """The name of this component""" + return self["name"] @property def dtype(self): - ''' The data type of this component''' - return self['dtype'] + """The data type of this component""" + return self["dtype"] @staticmethod def assertValidDtype(dtype): @@ -612,87 +707,150 @@ def check_valid_dtype(dtype): return True @staticmethod - @docval({'name': 'spec', 'type': (str, dict), 'doc': 'the spec object to check'}, is_method=False) + @docval( + {"name": "spec", "type": (str, dict), "doc": "the spec object to check"}, + is_method=False, + ) def is_ref(**kwargs): - spec = getargs('spec', kwargs) + spec = getargs("spec", kwargs) spec_is_ref = False if isinstance(spec, dict): if _target_type_key in spec: spec_is_ref = True - elif 'dtype' in spec and isinstance(spec['dtype'], dict) and _target_type_key in spec['dtype']: + elif "dtype" in spec and isinstance(spec["dtype"], dict) and _target_type_key in spec["dtype"]: spec_is_ref = True return spec_is_ref @classmethod def build_const_args(cls, spec_dict): - ''' Build constructor arguments for this Spec class from a dictionary ''' + """Build constructor arguments for this Spec class from a dictionary""" ret = super().build_const_args(spec_dict) - if isinstance(ret['dtype'], list): - ret['dtype'] = list(map(cls.build_const_args, ret['dtype'])) - elif isinstance(ret['dtype'], dict): - ret['dtype'] = RefSpec.build_spec(ret['dtype']) + if isinstance(ret["dtype"], list): + ret["dtype"] = list(map(cls.build_const_args, ret["dtype"])) + elif isinstance(ret["dtype"], dict): + ret["dtype"] = RefSpec.build_spec(ret["dtype"]) return ret _dataset_args = [ - {'name': 'doc', 'type': str, 'doc': 'a description about what this specification represents'}, - {'name': 'dtype', 'type': (str, list, RefSpec), - 'doc': 'The data type of this attribute. 
Use a list of DtypeSpecs to specify a compound data type.', - 'default': None}, - {'name': 'name', 'type': str, 'doc': 'The name of this dataset', 'default': None}, - {'name': 'default_name', 'type': str, 'doc': 'The default name of this dataset', 'default': None}, - {'name': 'shape', 'type': (list, tuple), 'doc': 'the shape of this dataset', 'default': None}, - {'name': 'dims', 'type': (list, tuple), 'doc': 'the dimensions of this dataset', 'default': None}, - {'name': 'attributes', 'type': list, 'doc': 'the attributes on this group', 'default': list()}, - {'name': 'linkable', 'type': bool, 'doc': 'whether or not this group can be linked', 'default': True}, - {'name': 'quantity', 'type': (str, int), 'doc': 'the required number of allowed instance', 'default': 1}, - {'name': 'default_value', 'type': None, 'doc': 'a default value for this dataset', 'default': None}, - {'name': 'data_type_def', 'type': str, 'doc': 'the data type this specification represents', 'default': None}, - {'name': 'data_type_inc', 'type': (str, 'DatasetSpec'), - 'doc': 'the data type this specification extends', 'default': None}, + { + "name": "doc", + "type": str, + "doc": "a description about what this specification represents", + }, + { + "name": "dtype", + "type": (str, list, RefSpec), + "doc": "The data type of this attribute. Use a list of DtypeSpecs to specify a compound data type.", + "default": None, + }, + { + "name": "name", + "type": str, + "doc": "The name of this dataset", + "default": None, + }, + { + "name": "default_name", + "type": str, + "doc": "The default name of this dataset", + "default": None, + }, + { + "name": "shape", + "type": (list, tuple), + "doc": "the shape of this dataset", + "default": None, + }, + { + "name": "dims", + "type": (list, tuple), + "doc": "the dimensions of this dataset", + "default": None, + }, + { + "name": "attributes", + "type": list, + "doc": "the attributes on this group", + "default": list(), + }, + { + "name": "linkable", + "type": bool, + "doc": "whether or not this group can be linked", + "default": True, + }, + { + "name": "quantity", + "type": (str, int), + "doc": "the required number of allowed instance", + "default": 1, + }, + { + "name": "default_value", + "type": None, + "doc": "a default value for this dataset", + "default": None, + }, + { + "name": "data_type_def", + "type": str, + "doc": "the data type this specification represents", + "default": None, + }, + { + "name": "data_type_inc", + "type": (str, "DatasetSpec"), + "doc": "the data type this specification extends", + "default": None, + }, ] class DatasetSpec(BaseStorageSpec): - ''' Specification for datasets + """Specification for datasets To specify a table-like dataset i.e. a compound data type. 
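# A minimal, illustrative sketch (not part of the HDMF sources in this diff) of the
# compound-dtype path enforced just below in DatasetSpec.__init__: every column of a
# table-like dataset must be a DtypeSpec, otherwise a ValueError is raised. The
# hdmf.spec import path and the example names/dtypes are assumptions.
from hdmf.spec import DatasetSpec, DtypeSpec

interval_columns = [
    DtypeSpec(name="start_time", doc="start of the interval", dtype="float"),
    DtypeSpec(name="label", doc="free-text label for the interval", dtype="text"),
]
intervals = DatasetSpec(
    doc="a table-like dataset with one row per interval",
    name="intervals",
    dtype=interval_columns,   # a list of DtypeSpec -> compound data type
    quantity="zero_or_one",   # fixed-name specs only allow quantity 1 or zero_or_one
)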
- ''' + """ @docval(*_dataset_args) def __init__(self, **kwargs): - doc, shape, dims, dtype, default_value = popargs('doc', 'shape', 'dims', 'dtype', 'default_value', kwargs) + doc, shape, dims, dtype, default_value = popargs("doc", "shape", "dims", "dtype", "default_value", kwargs) if shape is not None: - self['shape'] = shape + self["shape"] = shape if dims is not None: - self['dims'] = dims - if 'shape' not in self: - self['shape'] = tuple([None] * len(dims)) + self["dims"] = dims + if "shape" not in self: + self["shape"] = tuple([None] * len(dims)) if self.shape is not None and self.dims is not None: - if len(self['dims']) != len(self['shape']): + if len(self["dims"]) != len(self["shape"]): raise ValueError("'dims' and 'shape' must be the same length") if dtype is not None: if isinstance(dtype, list): # Dtype is a compound data type for _i, col in enumerate(dtype): if not isinstance(col, DtypeSpec): - msg = ('must use DtypeSpec if defining compound dtype - found %s at element %d' - % (type(col), _i)) + msg = "must use DtypeSpec if defining compound dtype - found %s at element %d" % ( + type(col), + _i, + ) raise ValueError(msg) else: DtypeHelper.check_dtype(dtype) - self['dtype'] = dtype + self["dtype"] = dtype super().__init__(doc, **kwargs) if default_value is not None: - self['default_value'] = default_value + self["default_value"] = default_value if self.name is not None: - valid_quant_vals = [1, 'zero_or_one', ZERO_OR_ONE] + valid_quant_vals = [1, "zero_or_one", ZERO_OR_ONE] if self.quantity not in valid_quant_vals: - raise ValueError("quantity %s invalid for spec with fixed name. Valid values are: %s" % - (self.quantity, str(valid_quant_vals))) + raise ValueError( + "quantity %s invalid for spec with fixed name. Valid values are: %s" + % (self.quantity, str(valid_quant_vals)) + ) @classmethod def __get_prec_level(cls, dtype): - m = re.search('[0-9]+', dtype) + m = re.search("[0-9]+", dtype) if m is not None: prec = int(m.group()) else: @@ -713,107 +871,126 @@ def __is_sub_dtype(cls, orig, new): return False return new_prec >= orig_prec - @docval({'name': 'inc_spec', 'type': 'DatasetSpec', 'doc': 'the data type this specification represents'}) + @docval({"name": "inc_spec", "type": "DatasetSpec", "doc": "the data type this specification represents"}) def resolve_spec(self, **kwargs): - inc_spec = getargs('inc_spec', kwargs) + inc_spec = getargs("inc_spec", kwargs) if isinstance(self.dtype, list): # merge the new types inc_dtype = inc_spec.dtype if isinstance(inc_dtype, str): - msg = 'Cannot extend simple data type to compound data type' + msg = "Cannot extend simple data type to compound data type" raise ValueError(msg) order = OrderedDict() if inc_dtype is not None: for dt in inc_dtype: - order[dt['name']] = dt + order[dt["name"]] = dt for dt in self.dtype: - name = dt['name'] + name = dt["name"] if name in order: # verify that the extension has supplied # a valid subtyping of existing type orig = order[name].dtype new = dt.dtype if not self.__is_sub_dtype(orig, new): - msg = 'Cannot extend %s to %s' % (str(orig), str(new)) + msg = "Cannot extend %s to %s" % (str(orig), str(new)) raise ValueError(msg) order[name] = dt - self['dtype'] = list(order.values()) + self["dtype"] = list(order.values()) super().resolve_spec(inc_spec) @property def dims(self): - ''' The dimensions of this Dataset ''' - return self.get('dims', None) + """The dimensions of this Dataset""" + return self.get("dims", None) @property def dtype(self): - ''' The data type of the Dataset ''' - return 
self.get('dtype', None) + """The data type of the Dataset""" + return self.get("dtype", None) @property def shape(self): - ''' The shape of the dataset ''' - return self.get('shape', None) + """The shape of the dataset""" + return self.get("shape", None) @property def default_value(self): - '''The default value of the dataset or None if not specified''' - return self.get('default_value', None) + """The default value of the dataset or None if not specified""" + return self.get("default_value", None) @classmethod def dtype_spec_cls(cls): - ''' The class to use when constructing DtypeSpec objects + """The class to use when constructing DtypeSpec objects - Override this if extending to use a class other than DtypeSpec to build - dataset specifications - ''' + Override this if extending to use a class other than DtypeSpec to build + dataset specifications + """ return DtypeSpec @classmethod def build_const_args(cls, spec_dict): - ''' Build constructor arguments for this Spec class from a dictionary ''' + """Build constructor arguments for this Spec class from a dictionary""" ret = super().build_const_args(spec_dict) - if 'dtype' in ret: - if isinstance(ret['dtype'], list): - ret['dtype'] = list(map(cls.dtype_spec_cls().build_spec, ret['dtype'])) - elif isinstance(ret['dtype'], dict): - ret['dtype'] = RefSpec.build_spec(ret['dtype']) + if "dtype" in ret: + if isinstance(ret["dtype"], list): + ret["dtype"] = list(map(cls.dtype_spec_cls().build_spec, ret["dtype"])) + elif isinstance(ret["dtype"], dict): + ret["dtype"] = RefSpec.build_spec(ret["dtype"]) return ret _link_args = [ - {'name': 'doc', 'type': str, 'doc': 'a description about what this link represents'}, - {'name': _target_type_key, 'type': (str, BaseStorageSpec), 'doc': 'the target type GroupSpec or DatasetSpec'}, - {'name': 'quantity', 'type': (str, int), 'doc': 'the required number of allowed instance', 'default': 1}, - {'name': 'name', 'type': str, 'doc': 'the name of this link', 'default': None} + { + "name": "doc", + "type": str, + "doc": "a description about what this link represents", + }, + { + "name": _target_type_key, + "type": (str, BaseStorageSpec), + "doc": "the target type GroupSpec or DatasetSpec", + }, + { + "name": "quantity", + "type": (str, int), + "doc": "the required number of allowed instance", + "default": 1, + }, + { + "name": "name", + "type": str, + "doc": "the name of this link", + "default": None, + }, ] class LinkSpec(Spec): - @docval(*_link_args) def __init__(self, **kwargs): - doc, target_type, name, quantity = popargs('doc', _target_type_key, 'name', 'quantity', kwargs) + doc, target_type, name, quantity = popargs("doc", _target_type_key, "name", "quantity", kwargs) super().__init__(doc, name, **kwargs) if isinstance(target_type, BaseStorageSpec): if target_type.data_type_def is None: - msg = ("'%s' must be a string or a GroupSpec or DatasetSpec with a '%s' key." - % (_target_type_key, target_type.def_key())) + msg = "'%s' must be a string or a GroupSpec or DatasetSpec with a '%s' key." 
% ( + _target_type_key, + target_type.def_key(), + ) raise ValueError(msg) self[_target_type_key] = target_type.data_type_def else: self[_target_type_key] = target_type if quantity != 1: - self['quantity'] = quantity + self["quantity"] = quantity @property def target_type(self): - ''' The data type of target specification ''' + """The data type of target specification""" return self.get(_target_type_key) @property def data_type_inc(self): - ''' The data type of target specification ''' + """The data type of target specification""" return self.get(_target_type_key) def is_many(self): @@ -821,54 +998,103 @@ def is_many(self): @property def quantity(self): - ''' The number of times the object being specified should be present ''' - return self.get('quantity', DEF_QUANTITY) + """The number of times the object being specified should be present""" + return self.get("quantity", DEF_QUANTITY) @property def required(self): - ''' Whether or not the this spec represents a required field ''' + """Whether or not the this spec represents a required field""" return self.quantity not in (ZERO_OR_ONE, ZERO_OR_MANY) _group_args = [ - {'name': 'doc', 'type': str, 'doc': 'a description about what this specification represents'}, { - 'name': 'name', - 'type': str, - 'doc': 'the name of the Group that is written to the file. If this argument is omitted, users will be ' - 'required to enter a ``name`` field when creating instances of this data type in the API. Another ' - 'option is to specify ``default_name``, in which case this name will be used as the name of the Group ' - 'if no other name is provided.', - 'default': None, + "name": "doc", + "type": str, + "doc": "a description about what this specification represents", + }, + { + "name": "name", + "type": str, + "doc": ( + "the name of the Group that is written to the file. If this argument is" + " omitted, users will be required to enter a ``name`` field when creating" + " instances of this data type in the API. Another option is to specify" + " ``default_name``, in which case this name will be used as the name of the" + " Group if no other name is provided." + ), + "default": None, + }, + { + "name": "default_name", + "type": str, + "doc": "The default name of this group", + "default": None, + }, + { + "name": "groups", + "type": list, + "doc": "the subgroups in this group", + "default": list(), + }, + { + "name": "datasets", + "type": list, + "doc": "the datasets in this group", + "default": list(), + }, + { + "name": "attributes", + "type": list, + "doc": "the attributes on this group", + "default": list(), + }, + { + "name": "links", + "type": list, + "doc": "the links in this group", + "default": list(), + }, + { + "name": "linkable", + "type": bool, + "doc": "whether or not this group can be linked", + "default": True, + }, + { + "name": "quantity", + "type": (str, int), + "doc": ( + "the allowable number of instance of this group in a certain location. See" + " table of options `here" + " `_." + " Note that if youspecify ``name``, ``quantity`` cannot be ``'*'``," + " ``'+'``, or an integer greater that 1, because you cannot have more than" + " one group of the same name in the same parent group." 
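# An illustrative sketch of the quantity rules documented in _group_args above: an
# unnamed, typed subgroup may repeat ('*'), while a named group cannot use '*', '+',
# or an integer greater than 1. The type names ("ProcessingModule", "DynamicTable")
# and the hdmf.spec import path are assumptions, not taken from this diff.
from hdmf.spec import GroupSpec, LinkSpec

processing = GroupSpec(
    doc="a fixed-name group holding any number of typed modules",
    name="processing",
    groups=[
        GroupSpec(doc="a processing module", data_type_inc="ProcessingModule", quantity="*"),
    ],
    links=[
        LinkSpec(doc="optional link to a shared table", target_type="DynamicTable", quantity="?"),
    ],
)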
+ ), + "default": 1, + }, + { + "name": "data_type_def", + "type": str, + "doc": "the data type this specification represents", + "default": None, }, - {'name': 'default_name', 'type': str, 'doc': 'The default name of this group', 'default': None}, - {'name': 'groups', 'type': list, 'doc': 'the subgroups in this group', 'default': list()}, - {'name': 'datasets', 'type': list, 'doc': 'the datasets in this group', 'default': list()}, - {'name': 'attributes', 'type': list, 'doc': 'the attributes on this group', 'default': list()}, - {'name': 'links', 'type': list, 'doc': 'the links in this group', 'default': list()}, - {'name': 'linkable', 'type': bool, 'doc': 'whether or not this group can be linked', 'default': True}, { - 'name': 'quantity', - 'type': (str, int), - 'doc': "the allowable number of instance of this group in a certain location. See table of options " - "`here `_. Note that if you" - "specify ``name``, ``quantity`` cannot be ``'*'``, ``'+'``, or an integer greater that 1, because you " - "cannot have more than one group of the same name in the same parent group.", - 'default': 1, + "name": "data_type_inc", + "type": (str, "GroupSpec"), + "doc": "the data type this specification data_type_inc", + "default": None, }, - {'name': 'data_type_def', 'type': str, 'doc': 'the data type this specification represents', 'default': None}, - {'name': 'data_type_inc', 'type': (str, 'GroupSpec'), - 'doc': 'the data type this specification data_type_inc', 'default': None}, ] class GroupSpec(BaseStorageSpec): - ''' Specification for groups - ''' + """Specification for groups""" @docval(*_group_args) def __init__(self, **kwargs): - doc, groups, datasets, links = popargs('doc', 'groups', 'datasets', 'links', kwargs) + doc, groups, datasets, links = popargs("doc", "groups", "datasets", "links", kwargs) self.__data_types = dict() # for GroupSpec/DatasetSpec data_type_def/inc self.__target_types = dict() # for LinkSpec target_types self.__groups = dict() @@ -890,9 +1116,9 @@ def __init__(self, **kwargs): self.__overridden_groups = set() super().__init__(doc, **kwargs) - @docval({'name': 'inc_spec', 'type': 'GroupSpec', 'doc': 'the data type this specification represents'}) + @docval({"name": "inc_spec", "type": "GroupSpec", "doc": "the data type this specification represents"}) def resolve_spec(self, **kwargs): - inc_spec = getargs('inc_spec', kwargs) + inc_spec = getargs("inc_spec", kwargs) data_types = list() target_types = list() # resolve inherited datasets @@ -937,9 +1163,9 @@ def resolve_spec(self, **kwargs): dt = dt_spec.data_type_inc self.__new_data_types.discard(dt) existing_dt_spec = self.get_data_type(dt) - if (existing_dt_spec is None or - ((isinstance(existing_dt_spec, list) or existing_dt_spec.name is not None) and - dt_spec.name is None)): + if existing_dt_spec is None or ( + (isinstance(existing_dt_spec, list) or existing_dt_spec.name is not None) and dt_spec.name is None + ): if isinstance(dt_spec, DatasetSpec): self.set_dataset(dt_spec) else: @@ -949,79 +1175,93 @@ def resolve_spec(self, **kwargs): dt = link_spec.target_type self.__new_target_types.discard(dt) existing_dt_spec = self.get_target_type(dt) - if (existing_dt_spec is None or - (isinstance(existing_dt_spec, list) or existing_dt_spec.name is not None) and - link_spec.name is None): + if ( + existing_dt_spec is None + or (isinstance(existing_dt_spec, list) or existing_dt_spec.name is not None) + and link_spec.name is None + ): self.set_link(link_spec) super().resolve_spec(inc_spec) - @docval({'name': 'name', 'type': str, 
'doc': 'the name of the dataset'}, - raises="ValueError, if 'name' is not part of this spec") + @docval( + {"name": "name", "type": str, "doc": "the name of the dataset"}, + raises="ValueError, if 'name' is not part of this spec", + ) def is_inherited_dataset(self, **kwargs): - '''Return true if a dataset with the given name was inherited''' - name = getargs('name', kwargs) + """Return true if a dataset with the given name was inherited""" + name = getargs("name", kwargs) if name not in self.__datasets: raise ValueError("Dataset '%s' not found in spec" % name) return name not in self.__new_datasets - @docval({'name': 'name', 'type': str, 'doc': 'the name of the dataset'}, - raises="ValueError, if 'name' is not part of this spec") + @docval( + {"name": "name", "type": str, "doc": "the name of the dataset"}, + raises="ValueError, if 'name' is not part of this spec", + ) def is_overridden_dataset(self, **kwargs): - '''Return true if a dataset with the given name overrides a specification from the parent type''' - name = getargs('name', kwargs) + """Return true if a dataset with the given name overrides a specification from the parent type""" + name = getargs("name", kwargs) if name not in self.__datasets: raise ValueError("Dataset '%s' not found in spec" % name) return name in self.__overridden_datasets - @docval({'name': 'name', 'type': str, 'doc': 'the name of the group'}, - raises="ValueError, if 'name' is not part of this spec") + @docval( + {"name": "name", "type": str, "doc": "the name of the group"}, + raises="ValueError, if 'name' is not part of this spec", + ) def is_inherited_group(self, **kwargs): - '''Return true if a group with the given name was inherited''' - name = getargs('name', kwargs) + """Return true if a group with the given name was inherited""" + name = getargs("name", kwargs) if name not in self.__groups: raise ValueError("Group '%s' not found in spec" % name) return name not in self.__new_groups - @docval({'name': 'name', 'type': str, 'doc': 'the name of the group'}, - raises="ValueError, if 'name' is not part of this spec") + @docval( + {"name": "name", "type": str, "doc": "the name of the group"}, + raises="ValueError, if 'name' is not part of this spec", + ) def is_overridden_group(self, **kwargs): - '''Return true if a group with the given name overrides a specification from the parent type''' - name = getargs('name', kwargs) + """Return true if a group with the given name overrides a specification from the parent type""" + name = getargs("name", kwargs) if name not in self.__groups: raise ValueError("Group '%s' not found in spec" % name) return name in self.__overridden_groups - @docval({'name': 'name', 'type': str, 'doc': 'the name of the link'}, - raises="ValueError, if 'name' is not part of this spec") + @docval( + {"name": "name", "type": str, "doc": "the name of the link"}, + raises="ValueError, if 'name' is not part of this spec", + ) def is_inherited_link(self, **kwargs): - '''Return true if a link with the given name was inherited''' - name = getargs('name', kwargs) + """Return true if a link with the given name was inherited""" + name = getargs("name", kwargs) if name not in self.__links: raise ValueError("Link '%s' not found in spec" % name) return name not in self.__new_links - @docval({'name': 'name', 'type': str, 'doc': 'the name of the link'}, - raises="ValueError, if 'name' is not part of this spec") + @docval( + {"name": "name", "type": str, "doc": "the name of the link"}, + raises="ValueError, if 'name' is not part of this spec", + ) def 
is_overridden_link(self, **kwargs): - '''Return true if a link with the given name overrides a specification from the parent type''' - name = getargs('name', kwargs) + """Return true if a link with the given name overrides a specification from the parent type""" + name = getargs("name", kwargs) if name not in self.__links: raise ValueError("Link '%s' not found in spec" % name) return name in self.__overridden_links - @docval({'name': 'spec', 'type': (Spec, str), 'doc': 'the specification to check'}) + @docval({"name": "spec", "type": (Spec, str), "doc": "the specification to check"}) def is_inherited_spec(self, **kwargs): - ''' Returns 'True' if specification was inherited from a parent type ''' - spec = getargs('spec', kwargs) + """Returns 'True' if specification was inherited from a parent type""" + spec = getargs("spec", kwargs) if isinstance(spec, Spec): name = spec.name - if name is None and hasattr(spec, 'data_type_def'): + if name is None and hasattr(spec, "data_type_def"): name = spec.data_type_def if name is None: # NOTE: this will return the target type for LinkSpecs name = spec.data_type_inc if name is None: # pragma: no cover # this should not be possible - raise ValueError('received Spec with wildcard name but no data_type_inc or data_type_def') + raise ValueError("received Spec with wildcard name but no data_type_inc or data_type_def") spec = name # if the spec has a name, it will be found in __links/__groups/__datasets before __data_types/__target_types if spec in self.__links: @@ -1049,10 +1289,10 @@ def is_inherited_spec(self, **kwargs): return True return False - @docval({'name': 'spec', 'type': (Spec, str), 'doc': 'the specification to check'}) + @docval({"name": "spec", "type": (Spec, str), "doc": "the specification to check"}) def is_overridden_spec(self, **kwargs): # noqa: C901 - ''' Returns 'True' if specification overrides a specification from the parent type ''' - spec = getargs('spec', kwargs) + """Returns 'True' if specification overrides a specification from the parent type""" + spec = getargs("spec", kwargs) if isinstance(spec, Spec): name = spec.name if name is None: @@ -1065,7 +1305,7 @@ def is_overridden_spec(self, **kwargs): # noqa: C901 name = spec.data_type_inc if name is None: # pragma: no cover # this should not happen - raise ValueError('received Spec with wildcard name but no data_type_inc or data_type_def') + raise ValueError("received Spec with wildcard name but no data_type_inc or data_type_def") spec = name # if the spec has a name, it will be found in __links/__groups/__datasets before __data_types/__target_types if spec in self.__links: @@ -1090,34 +1330,38 @@ def is_overridden_spec(self, **kwargs): # noqa: C901 return True return False - @docval({'name': 'spec', 'type': (BaseStorageSpec, str), 'doc': 'the specification to check'}) + @docval({"name": "spec", "type": (BaseStorageSpec, str), "doc": "the specification to check"}) def is_inherited_type(self, **kwargs): - ''' Returns True if `spec` represents a data type that was inherited ''' - spec = getargs('spec', kwargs) + """Returns True if `spec` represents a data type that was inherited""" + spec = getargs("spec", kwargs) if isinstance(spec, BaseStorageSpec): if spec.data_type_def is None: # why not also check data_type_inc? 
- raise ValueError('cannot check if something was inherited if it does not have a %s' % self.def_key()) + raise ValueError("cannot check if something was inherited if it does not have a %s" % self.def_key()) spec = spec.data_type_def return spec not in self.__new_data_types - @docval({'name': 'spec', 'type': (BaseStorageSpec, str), 'doc': 'the specification to check'}, - raises="ValueError, if 'name' is not part of this spec") + @docval( + {"name": "spec", "type": (BaseStorageSpec, str), "doc": "the specification to check"}, + raises="ValueError, if 'name' is not part of this spec", + ) def is_overridden_type(self, **kwargs): - ''' Returns True if `spec` represents a data type that overrides a specification from a parent type ''' + """Returns True if `spec` represents a data type that overrides a specification from a parent type""" return self.is_inherited_type(**kwargs) - @docval({'name': 'spec', 'type': (LinkSpec, str), 'doc': 'the specification to check'}) + @docval({"name": "spec", "type": (LinkSpec, str), "doc": "the specification to check"}) def is_inherited_target_type(self, **kwargs): - ''' Returns True if `spec` represents a target type that was inherited ''' - spec = getargs('spec', kwargs) + """Returns True if `spec` represents a target type that was inherited""" + spec = getargs("spec", kwargs) if isinstance(spec, LinkSpec): spec = spec.target_type return spec not in self.__new_target_types - @docval({'name': 'spec', 'type': (LinkSpec, str), 'doc': 'the specification to check'}, - raises="ValueError, if 'name' is not part of this spec") + @docval( + {"name": "spec", "type": (LinkSpec, str), "doc": "the specification to check"}, + raises="ValueError, if 'name' is not part of this spec", + ) def is_overridden_target_type(self, **kwargs): - ''' Returns True if `spec` represents a target type that overrides a specification from a parent type ''' + """Returns True if `spec` represents a target type that overrides a specification from a parent type""" return self.is_inherited_target_type(**kwargs) def __add_data_type_inc(self, spec): @@ -1130,9 +1374,9 @@ def __add_data_type_inc(self, spec): # stored in __data_types # it is not allowed to have multiple specs for a given data type and multiple are unnamed dt = None - if hasattr(spec, 'data_type_def') and spec.data_type_def is not None: + if hasattr(spec, "data_type_def") and spec.data_type_def is not None: dt = spec.data_type_def - elif hasattr(spec, 'data_type_inc') and spec.data_type_inc is not None: + elif hasattr(spec, "data_type_inc") and spec.data_type_inc is not None: dt = spec.data_type_inc if not dt: # pragma: no cover # this should not be possible @@ -1215,9 +1459,9 @@ def __add_target_type(self, spec): else: self.__target_types[dt] = spec - @docval({'name': 'data_type', 'type': str, 'doc': 'the data_type to retrieve'}) + @docval({"name": "data_type", "type": str, "doc": "the data_type to retrieve"}) def get_data_type(self, **kwargs): - ''' Get a specification by "data_type" + """Get a specification by "data_type" NOTE: If there is only one spec for a given data type, then it is returned. If there are multiple specs for a given data type and they are all named, then they are returned in a list. @@ -1225,13 +1469,13 @@ def get_data_type(self, **kwargs): The other named specs can be returned using get_group or get_dataset. NOTE: this method looks for an exact match of the data type and does not consider the type hierarchy. 
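# A small sketch of the lookup behaviour documented above (illustrative names,
# assuming the usual hdmf.spec import path): with two *named* datasets sharing a
# data_type_inc, get_data_type returns them as a list, while get_dataset fetches
# each one individually by name.
from hdmf.spec import GroupSpec

parent = GroupSpec(doc="a group with two typed datasets", data_type_def="MyGroup")
parent.add_dataset(doc="first table", name="table_a", data_type_inc="DynamicTable")
parent.add_dataset(doc="second table", name="table_b", data_type_inc="DynamicTable")

both_specs = parent.get_data_type("DynamicTable")  # list of the two named specs
table_a = parent.get_dataset("table_a")            # a single named spec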
- ''' - ndt = getargs('data_type', kwargs) + """ + ndt = getargs("data_type", kwargs) return self.__data_types.get(ndt, None) - @docval({'name': 'target_type', 'type': str, 'doc': 'the target_type to retrieve'}) + @docval({"name": "target_type", "type": str, "doc": "the target_type to retrieve"}) def get_target_type(self, **kwargs): - ''' Get a specification by "target_type" + """Get a specification by "target_type" NOTE: If there is only one spec for a given target type, then it is returned. If there are multiple specs for a given target type and they are all named, then they are returned in a list. @@ -1239,36 +1483,36 @@ def get_target_type(self, **kwargs): The other named specs can be returned using get_link. NOTE: this method looks for an exact match of the target type and does not consider the type hierarchy. - ''' - ndt = getargs('target_type', kwargs) + """ + ndt = getargs("target_type", kwargs) return self.__target_types.get(ndt, None) @property def groups(self): - ''' The groups specified in this GroupSpec ''' - return tuple(self.get('groups', tuple())) + """The groups specified in this GroupSpec""" + return tuple(self.get("groups", tuple())) @property def datasets(self): - ''' The datasets specified in this GroupSpec ''' - return tuple(self.get('datasets', tuple())) + """The datasets specified in this GroupSpec""" + return tuple(self.get("datasets", tuple())) @property def links(self): - ''' The links specified in this GroupSpec ''' - return tuple(self.get('links', tuple())) + """The links specified in this GroupSpec""" + return tuple(self.get("links", tuple())) @docval(*_group_args) def add_group(self, **kwargs): - ''' Add a new specification for a subgroup to this group specification ''' + """Add a new specification for a subgroup to this group specification""" spec = self.__class__(**kwargs) self.set_group(spec) return spec - @docval({'name': 'spec', 'type': ('GroupSpec'), 'doc': 'the specification for the subgroup'}) + @docval({"name": "spec", "type": "GroupSpec", "doc": "the specification for the subgroup"}) def set_group(self, **kwargs): - ''' Add the given specification for a subgroup to this group specification ''' - spec = getargs('spec', kwargs) + """Add the given specification for a subgroup to this group specification""" + spec = getargs("spec", kwargs) if spec.parent is not None: spec = self.build_spec(spec) if spec.name is None: @@ -1282,26 +1526,26 @@ def set_group(self, **kwargs): if spec.data_type_inc is not None or spec.data_type_def is not None: self.__add_data_type_inc(spec) self.__groups[spec.name] = spec - self.setdefault('groups', list()).append(spec) + self.setdefault("groups", list()).append(spec) spec.parent = self - @docval({'name': 'name', 'type': str, 'doc': 'the name of the group to the Spec for'}) + @docval({"name": "name", "type": str, "doc": "the name of the group to the Spec for"}) def get_group(self, **kwargs): - ''' Get a specification for a subgroup to this group specification ''' - name = getargs('name', kwargs) + """Get a specification for a subgroup to this group specification""" + name = getargs("name", kwargs) return self.__groups.get(name, self.__links.get(name)) @docval(*_dataset_args) def add_dataset(self, **kwargs): - ''' Add a new specification for a dataset to this group specification ''' + """Add a new specification for a dataset to this group specification""" spec = self.dataset_spec_cls()(**kwargs) self.set_dataset(spec) return spec - @docval({'name': 'spec', 'type': 'DatasetSpec', 'doc': 'the specification for the dataset'}) + 
@docval({"name": "spec", "type": "DatasetSpec", "doc": "the specification for the dataset"}) def set_dataset(self, **kwargs): - ''' Add the given specification for a dataset to this group specification ''' - spec = getargs('spec', kwargs) + """Add the given specification for a dataset to this group specification""" + spec = getargs("spec", kwargs) if spec.parent is not None: spec = self.dataset_spec_cls().build_spec(spec) if spec.name is None: @@ -1315,67 +1559,67 @@ def set_dataset(self, **kwargs): if spec.data_type_inc is not None or spec.data_type_def is not None: self.__add_data_type_inc(spec) self.__datasets[spec.name] = spec - self.setdefault('datasets', list()).append(spec) + self.setdefault("datasets", list()).append(spec) spec.parent = self - @docval({'name': 'name', 'type': str, 'doc': 'the name of the dataset to the Spec for'}) + @docval({"name": "name", "type": str, "doc": "the name of the dataset to the Spec for"}) def get_dataset(self, **kwargs): - ''' Get a specification for a dataset to this group specification ''' - name = getargs('name', kwargs) + """Get a specification for a dataset to this group specification""" + name = getargs("name", kwargs) return self.__datasets.get(name, self.__links.get(name)) @docval(*_link_args) def add_link(self, **kwargs): - ''' Add a new specification for a link to this group specification ''' + """Add a new specification for a link to this group specification""" spec = self.link_spec_cls()(**kwargs) self.set_link(spec) return spec - @docval({'name': 'spec', 'type': 'LinkSpec', 'doc': 'the specification for the object to link to'}) + @docval({"name": "spec", "type": "LinkSpec", "doc": "the specification for the object to link to"}) def set_link(self, **kwargs): - ''' Add a given specification for a link to this group specification ''' - spec = getargs('spec', kwargs) + """Add a given specification for a link to this group specification""" + spec = getargs("spec", kwargs) if spec.parent is not None: spec = self.link_spec_cls().build_spec(spec) # NOTE named specs can be present in both __links and __target_types self.__add_target_type(spec) if spec.name is not None: self.__links[spec.name] = spec - self.setdefault('links', list()).append(spec) + self.setdefault("links", list()).append(spec) spec.parent = self - @docval({'name': 'name', 'type': str, 'doc': 'the name of the link to the Spec for'}) + @docval({"name": "name", "type": str, "doc": "the name of the link to the Spec for"}) def get_link(self, **kwargs): - ''' Get a specification for a link to this group specification ''' - name = getargs('name', kwargs) + """Get a specification for a link to this group specification""" + name = getargs("name", kwargs) return self.__links.get(name) @classmethod def dataset_spec_cls(cls): - ''' The class to use when constructing DatasetSpec objects + """The class to use when constructing DatasetSpec objects - Override this if extending to use a class other than DatasetSpec to build - dataset specifications - ''' + Override this if extending to use a class other than DatasetSpec to build + dataset specifications + """ return DatasetSpec @classmethod def link_spec_cls(cls): - ''' The class to use when constructing LinkSpec objects + """The class to use when constructing LinkSpec objects - Override this if extending to use a class other than LinkSpec to build - link specifications - ''' + Override this if extending to use a class other than LinkSpec to build + link specifications + """ return LinkSpec @classmethod def build_const_args(cls, spec_dict): - ''' 
Build constructor arguments for this Spec class from a dictionary ''' + """Build constructor arguments for this Spec class from a dictionary""" ret = super().build_const_args(spec_dict) - if 'datasets' in ret: - ret['datasets'] = list(map(cls.dataset_spec_cls().build_spec, ret['datasets'])) - if 'groups' in ret: - ret['groups'] = list(map(cls.build_spec, ret['groups'])) - if 'links' in ret: - ret['links'] = list(map(cls.link_spec_cls().build_spec, ret['links'])) + if "datasets" in ret: + ret["datasets"] = list(map(cls.dataset_spec_cls().build_spec, ret["datasets"])) + if "groups" in ret: + ret["groups"] = list(map(cls.build_spec, ret["groups"])) + if "links" in ret: + ret["links"] = list(map(cls.link_spec_cls().build_spec, ret["links"])) return ret diff --git a/src/hdmf/spec/write.py b/src/hdmf/spec/write.py index 352e883f5..0323fb5df 100644 --- a/src/hdmf/spec/write.py +++ b/src/hdmf/spec/write.py @@ -5,16 +5,16 @@ from abc import ABCMeta, abstractmethod from collections import OrderedDict from datetime import datetime + import ruamel.yaml as yaml +from ..utils import docval, getargs, popargs from .catalog import SpecCatalog from .namespace import SpecNamespace -from .spec import GroupSpec, DatasetSpec -from ..utils import docval, getargs, popargs +from .spec import DatasetSpec, GroupSpec class SpecWriter(metaclass=ABCMeta): - @abstractmethod def write_spec(self, spec_file_dict, path): pass @@ -25,16 +25,20 @@ def write_namespace(self, namespace, path): class YAMLSpecWriter(SpecWriter): - - @docval({'name': 'outdir', - 'type': str, - 'doc': 'the path to write the directory to output the namespace and specs too', 'default': '.'}) + @docval( + { + "name": "outdir", + "type": str, + "doc": "the path to write the directory to output the namespace and specs too", + "default": ".", + } + ) def __init__(self, **kwargs): - self.__outdir = getargs('outdir', kwargs) + self.__outdir = getargs("outdir", kwargs) def __dump_spec(self, specs, stream): specs_plain_dict = json.loads(json.dumps(specs)) - yaml_obj = yaml.YAML(typ='safe', pure=True) + yaml_obj = yaml.YAML(typ="safe", pure=True) yaml_obj.default_flow_style = False yaml_obj.dump(specs_plain_dict, stream) @@ -42,7 +46,7 @@ def write_spec(self, spec_file_dict, path): out_fullpath = os.path.join(self.__outdir, path) spec_plain_dict = json.loads(json.dumps(spec_file_dict)) sorted_data = self.sort_keys(spec_plain_dict) - with open(out_fullpath, 'w') as fd_write: + with open(out_fullpath, "w") as fd_write: yaml_obj = yaml.YAML(pure=True) yaml_obj.dump(sorted_data, fd_write) @@ -52,20 +56,20 @@ def write_namespace(self, namespace, path): :param namespace: SpecNamespace holding the key-value pairs that define the namespace :param path: File path to write the namespace to as YAML under the key 'namespaces' """ - with open(os.path.join(self.__outdir, path), 'w') as stream: + with open(os.path.join(self.__outdir, path), "w") as stream: # Convert the date to a string if necessary ns = namespace - if 'date' in namespace and isinstance(namespace['date'], datetime): + if "date" in namespace and isinstance(namespace["date"], datetime): ns = copy.copy(ns) # copy the namespace to avoid side-effects - ns['date'] = ns['date'].isoformat() - self.__dump_spec({'namespaces': [ns]}, stream) + ns["date"] = ns["date"].isoformat() + self.__dump_spec({"namespaces": [ns]}, stream) def reorder_yaml(self, path): """ Open a YAML file, load it as python data, sort the data alphabetically, and write it back out to the same path. 
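# A short usage sketch for the writer above (the file name is illustrative and the
# file is assumed to exist): reorder_yaml round-trips an existing spec file through
# sort_keys so that keys such as data_type_def, name, and dtype come out in the
# canonical order listed just below. The same writer can also be handed to
# NamespaceBuilder.export via its 'writer' argument.
from hdmf.spec.write import YAMLSpecWriter

writer = YAMLSpecWriter(outdir=".")
writer.reorder_yaml("my-lab.extensions.yaml")  # rewrite the file with sorted keys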
""" - with open(path, 'rb') as fd_read: + with open(path, "rb") as fd_read: yaml_obj = yaml.YAML(pure=True) data = yaml_obj.load(fd_read) self.write_spec(data, path) @@ -73,26 +77,41 @@ def reorder_yaml(self, path): def sort_keys(self, obj): # Represent None as null def my_represent_none(self, data): - return self.represent_scalar(u'tag:yaml.org,2002:null', u'null') + return self.represent_scalar("tag:yaml.org,2002:null", "null") yaml.representer.RoundTripRepresenter.add_representer(type(None), my_represent_none) - order = ['neurodata_type_def', 'neurodata_type_inc', 'data_type_def', 'data_type_inc', - 'name', 'default_name', - 'dtype', 'target_type', 'dims', 'shape', 'default_value', 'value', 'doc', - 'required', 'quantity', 'attributes', 'datasets', 'groups', 'links'] + order = [ + "neurodata_type_def", + "neurodata_type_inc", + "data_type_def", + "data_type_inc", + "name", + "default_name", + "dtype", + "target_type", + "dims", + "shape", + "default_value", + "value", + "doc", + "required", + "quantity", + "attributes", + "datasets", + "groups", + "links", + ] if isinstance(obj, dict): keys = list(obj.keys()) for k in order[::-1]: if k in keys: keys.remove(k) keys.insert(0, k) - if 'neurodata_type_def' not in keys and 'name' in keys: - keys.remove('name') - keys.insert(0, 'name') - return yaml.comments.CommentedMap( - yaml.compat.ordereddict([(k, self.sort_keys(obj[k])) for k in keys]) - ) + if "neurodata_type_def" not in keys and "name" in keys: + keys.remove("name") + keys.insert(0, "name") + return yaml.comments.CommentedMap(yaml.compat.ordereddict([(k, self.sort_keys(obj[k])) for k in keys])) elif isinstance(obj, list): return [self.sort_keys(v) for v in obj] elif isinstance(obj, tuple): @@ -102,64 +121,131 @@ def my_represent_none(self, data): class NamespaceBuilder: - ''' A class for building namespace and spec files ''' - - @docval({'name': 'doc', 'type': str, 'doc': 'Description about what the namespace represents'}, - {'name': 'name', 'type': str, 'doc': 'Name of the namespace'}, - {'name': 'full_name', 'type': str, 'doc': 'Extended full name of the namespace', 'default': None}, - {'name': 'version', 'type': (str, tuple, list), 'doc': 'Version number of the namespace', 'default': None}, - {'name': 'author', 'type': (str, list), 'doc': 'Author or list of authors.', 'default': None}, - {'name': 'contact', 'type': (str, list), - 'doc': 'List of emails. Ordering should be the same as for author', 'default': None}, - {'name': 'date', 'type': (datetime, str), - 'doc': "Date last modified or released. Formatting is %Y-%m-%d %H:%M:%S, e.g, 2017-04-25 17:14:13", - 'default': None}, - {'name': 'namespace_cls', 'type': type, 'doc': 'the SpecNamespace type', 'default': SpecNamespace}) + """A class for building namespace and spec files""" + + @docval( + { + "name": "doc", + "type": str, + "doc": "Description about what the namespace represents", + }, + { + "name": "name", + "type": str, + "doc": "Name of the namespace", + }, + { + "name": "full_name", + "type": str, + "doc": "Extended full name of the namespace", + "default": None, + }, + { + "name": "version", + "type": (str, tuple, list), + "doc": "Version number of the namespace", + "default": None, + }, + { + "name": "author", + "type": (str, list), + "doc": "Author or list of authors.", + "default": None, + }, + { + "name": "contact", + "type": (str, list), + "doc": "List of emails. 
Ordering should be the same as for author", + "default": None, + }, + { + "name": "date", + "type": (datetime, str), + "doc": "Date last modified or released. Formatting is %Y-%m-%d %H:%M:%S, e.g, 2017-04-25 17:14:13", + "default": None, + }, + { + "name": "namespace_cls", + "type": type, + "doc": "the SpecNamespace type", + "default": SpecNamespace, + }, + ) def __init__(self, **kwargs): - ns_cls = popargs('namespace_cls', kwargs) - if kwargs['version'] is None: + ns_cls = popargs("namespace_cls", kwargs) + if kwargs["version"] is None: # version is required on write as of HDMF 1.5. this check should prevent the writing of namespace files # without a version - raise ValueError("Namespace '%s' missing key 'version'. Please specify a version for the extension." - % kwargs['name']) + raise ValueError( + "Namespace '%s' missing key 'version'. Please specify a version for the extension." % kwargs["name"] + ) self.__ns_args = copy.deepcopy(kwargs) self.__namespaces = OrderedDict() self.__sources = OrderedDict() self.__catalog = SpecCatalog() self.__dt_key = ns_cls.types_key() - @docval({'name': 'source', 'type': str, 'doc': 'the path to write the spec to'}, - {'name': 'spec', 'type': (GroupSpec, DatasetSpec), 'doc': 'the Spec to add'}) + @docval( + {"name": "source", "type": str, "doc": "the path to write the spec to"}, + {"name": "spec", "type": (GroupSpec, DatasetSpec), "doc": "the Spec to add"}, + ) def add_spec(self, **kwargs): - ''' Add a Spec to the namespace ''' - source, spec = getargs('source', 'spec', kwargs) + """Add a Spec to the namespace""" + source, spec = getargs("source", "spec", kwargs) self.__catalog.auto_register(spec, source) self.add_source(source) self.__sources[source].setdefault(self.__dt_key, list()).append(spec) - @docval({'name': 'source', 'type': str, 'doc': 'the path to write the spec to'}, - {'name': 'doc', 'type': str, 'doc': 'additional documentation for the source file', 'default': None}, - {'name': 'title', 'type': str, 'doc': 'optional heading to be used for the source', 'default': None}) + @docval( + { + "name": "source", + "type": str, + "doc": "the path to write the spec to", + }, + { + "name": "doc", + "type": str, + "doc": "additional documentation for the source file", + "default": None, + }, + { + "name": "title", + "type": str, + "doc": "optional heading to be used for the source", + "default": None, + }, + ) def add_source(self, **kwargs): - ''' Add a source file to the namespace ''' - source, doc, title = getargs('source', 'doc', 'title', kwargs) - if '/' in source or source[0] == '.': - raise ValueError('source must be a base file') - source_dict = {'source': source} + """Add a source file to the namespace""" + source, doc, title = getargs("source", "doc", "title", kwargs) + if "/" in source or source[0] == ".": + raise ValueError("source must be a base file") + source_dict = {"source": source} self.__sources.setdefault(source, source_dict) # Update the doc and title if given if doc is not None: - self.__sources[source]['doc'] = doc + self.__sources[source]["doc"] = doc if title is not None: - self.__sources[source]['title'] = doc - - @docval({'name': 'data_type', 'type': str, 'doc': 'the data type to include'}, - {'name': 'source', 'type': str, 'doc': 'the source file to include the type from', 'default': None}, - {'name': 'namespace', 'type': str, - 'doc': 'the namespace from which to include the data type', 'default': None}) + self.__sources[source]["title"] = doc + + @docval( + {"name": "data_type", "type": str, "doc": "the data type to 
include"}, + { + "name": "source", + "type": str, + "doc": "the source file to include the type from", + "default": None, + }, + { + "name": "namespace", + "type": str, + "doc": "the namespace from which to include the data type", + "default": None, + }, + ) def include_type(self, **kwargs): - ''' Include a data type from an existing namespace or source ''' - dt, src, ns = getargs('data_type', 'source', 'namespace', kwargs) + """Include a data type from an existing namespace or source""" + dt, src, ns = getargs("data_type", "source", "namespace", kwargs) if src is not None: self.add_source(src) self.__sources[src].setdefault(self.__dt_key, list()).append(dt) @@ -169,32 +255,40 @@ def include_type(self, **kwargs): else: raise ValueError("must specify 'source' or 'namespace' when including type") - @docval({'name': 'namespace', 'type': str, 'doc': 'the namespace to include'}) + @docval({"name": "namespace", "type": str, "doc": "the namespace to include"}) def include_namespace(self, **kwargs): - ''' Include an entire namespace ''' - namespace = getargs('namespace', kwargs) - self.__namespaces.setdefault(namespace, {'namespace': namespace}) - - @docval({'name': 'path', 'type': str, 'doc': 'the path to write the spec to'}, - {'name': 'outdir', - 'type': str, - 'doc': 'the path to write the directory to output the namespace and specs too', 'default': '.'}, - {'name': 'writer', - 'type': SpecWriter, - 'doc': 'the SpecWriter to use to write the namespace', 'default': None}) + """Include an entire namespace""" + namespace = getargs("namespace", kwargs) + self.__namespaces.setdefault(namespace, {"namespace": namespace}) + + @docval( + {"name": "path", "type": str, "doc": "the path to write the spec to"}, + { + "name": "outdir", + "type": str, + "doc": "the path to write the directory to output the namespace and specs too", + "default": ".", + }, + { + "name": "writer", + "type": SpecWriter, + "doc": "the SpecWriter to use to write the namespace", + "default": None, + }, + ) def export(self, **kwargs): - ''' Export the namespace to the given path. + """Export the namespace to the given path. All new specification source files will be written in the same directory as the given path. 
- ''' - ns_path, writer = getargs('path', 'writer', kwargs) + """ + ns_path, writer = getargs("path", "writer", kwargs) if writer is None: - writer = YAMLSpecWriter(outdir=getargs('outdir', kwargs)) + writer = YAMLSpecWriter(outdir=getargs("outdir", kwargs)) ns_args = copy.copy(self.__ns_args) - ns_args['schema'] = list() + ns_args["schema"] = list() for ns, info in self.__namespaces.items(): - ns_args['schema'].append(info) + ns_args["schema"].append(info) for path, info in self.__sources.items(): out = SpecFileBuilder() dts = list() @@ -203,35 +297,34 @@ def export(self, **kwargs): dts.append(spec) else: out.add_spec(spec) - item = {'source': path} - if 'doc' in info: - item['doc'] = info['doc'] - if 'title' in info: - item['title'] = info['title'] + item = {"source": path} + if "doc" in info: + item["doc"] = info["doc"] + if "title" in info: + item["title"] = info["title"] if out and dts: - raise ValueError('cannot include from source if writing to source') + raise ValueError("cannot include from source if writing to source") elif dts: item[self.__dt_key] = dts elif out: writer.write_spec(out, path) - ns_args['schema'].append(item) + ns_args["schema"].append(item) namespace = SpecNamespace.build_namespace(**ns_args) writer.write_namespace(namespace, ns_path) @property def name(self): - return self.__ns_args['name'] + return self.__ns_args["name"] class SpecFileBuilder(dict): - - @docval({'name': 'spec', 'type': (GroupSpec, DatasetSpec), 'doc': 'the Spec to add'}) + @docval({"name": "spec", "type": (GroupSpec, DatasetSpec), "doc": "the Spec to add"}) def add_spec(self, **kwargs): - spec = getargs('spec', kwargs) + spec = getargs("spec", kwargs) if isinstance(spec, GroupSpec): - self.setdefault('groups', list()).append(spec) + self.setdefault("groups", list()).append(spec) elif isinstance(spec, DatasetSpec): - self.setdefault('datasets', list()).append(spec) + self.setdefault("datasets", list()).append(spec) def export_spec(ns_builder, new_data_types, output_dir): @@ -247,11 +340,11 @@ def export_spec(ns_builder, new_data_types, output_dir): """ if len(new_data_types) == 0: - warnings.warn('No data types specified. Exiting.') + warnings.warn("No data types specified. 
Exiting.") return - ns_path = ns_builder.name + '.namespace.yaml' - ext_path = ns_builder.name + '.extensions.yaml' + ns_path = ns_builder.name + ".namespace.yaml" + ext_path = ns_builder.name + ".extensions.yaml" for data_type in new_data_types: ns_builder.add_spec(ext_path, data_type) diff --git a/src/hdmf/testing/__init__.py b/src/hdmf/testing/__init__.py index cdf746388..240089732 100644 --- a/src/hdmf/testing/__init__.py +++ b/src/hdmf/testing/__init__.py @@ -1,2 +1,2 @@ -from .testcase import TestCase, H5RoundTripMixin +from .testcase import H5RoundTripMixin, TestCase from .utils import remove_test_file diff --git a/src/hdmf/testing/testcase.py b/src/hdmf/testing/testcase.py index f36ecc186..5ebeb8ea1 100644 --- a/src/hdmf/testing/testcase.py +++ b/src/hdmf/testing/testcase.py @@ -1,16 +1,18 @@ -import numpy as np import os import re import unittest from abc import ABCMeta, abstractmethod -from .utils import remove_test_file +import numpy as np + from ..backends.hdf5 import HDF5IO from ..build import Builder -from ..common import validate as common_validate, get_manager +from ..common import get_manager +from ..common import validate as common_validate from ..container import AbstractContainer, Container, Data -from ..utils import get_docval_macro from ..data_utils import AbstractDataChunkIterator +from ..utils import get_docval_macro +from .utils import remove_test_file class TestCase(unittest.TestCase): @@ -24,7 +26,7 @@ def assertRaisesWith(self, exc_type, exc_msg, *args, **kwargs): assertRaisesRegex, but checks for an exact match. """ - return self.assertRaisesRegex(exc_type, '^%s$' % re.escape(exc_msg), *args, **kwargs) + return self.assertRaisesRegex(exc_type, "^%s$" % re.escape(exc_msg), *args, **kwargs) def assertWarnsWith(self, warn_type, exc_msg, *args, **kwargs): """ @@ -32,15 +34,17 @@ def assertWarnsWith(self, warn_type, exc_msg, *args, **kwargs): assertWarnsRegex, but checks for an exact match. """ - return self.assertWarnsRegex(warn_type, '^%s$' % re.escape(exc_msg), *args, **kwargs) - - def assertContainerEqual(self, - container1, - container2, - ignore_name=False, - ignore_hdmf_attrs=False, - ignore_string_to_byte=False, - message=None): + return self.assertWarnsRegex(warn_type, "^%s$" % re.escape(exc_msg), *args, **kwargs) + + def assertContainerEqual( + self, + container1, + container2, + ignore_name=False, + ignore_hdmf_attrs=False, + ignore_string_to_byte=False, + message=None, + ): """ Asserts that the two AbstractContainers have equal contents. This applies to both Container and Data types. 
@@ -62,7 +66,11 @@ def assertContainerEqual(self, if not ignore_name: self.assertEqual(container1.name, container2.name, message) if not ignore_hdmf_attrs: - self.assertEqual(container1.container_source, container2.container_source, message) + self.assertEqual( + container1.container_source, + container2.container_source, + message, + ) self.assertEqual(container1.object_id, container2.object_id, message) # NOTE: parent is not tested because it can lead to infinite loops if isinstance(container1, Container): @@ -74,17 +82,22 @@ def assertContainerEqual(self, with self.subTest(field=field, container_type=type1.__name__): f1 = getattr(container1, field) f2 = getattr(container2, field) - self._assert_field_equal(f1, f2, - ignore_hdmf_attrs=ignore_hdmf_attrs, - ignore_string_to_byte=ignore_string_to_byte, - message=message) - - def _assert_field_equal(self, - f1, - f2, - ignore_hdmf_attrs=False, - ignore_string_to_byte=False, - message=None): + self._assert_field_equal( + f1, + f2, + ignore_hdmf_attrs=ignore_hdmf_attrs, + ignore_string_to_byte=ignore_string_to_byte, + message=message, + ) + + def _assert_field_equal( + self, + f1, + f2, + ignore_hdmf_attrs=False, + ignore_string_to_byte=False, + message=None, + ): """ Internal helper function used to compare two fields from Container objects @@ -95,12 +108,15 @@ def _assert_field_equal(self, :param ignore_string_to_byte: ignore conversion of str to bytes and compare as unicode instead :param message: custom additional message to show when assertions as part of this assert are failing """ - array_data_types = get_docval_macro('array_data') - if (isinstance(f1, array_data_types) or isinstance(f2, array_data_types)): - self._assert_array_equal(f1, f2, - ignore_hdmf_attrs=ignore_hdmf_attrs, - ignore_string_to_byte=ignore_string_to_byte, - message=message) + array_data_types = get_docval_macro("array_data") + if isinstance(f1, array_data_types) or isinstance(f2, array_data_types): + self._assert_array_equal( + f1, + f2, + ignore_hdmf_attrs=ignore_hdmf_attrs, + ignore_string_to_byte=ignore_string_to_byte, + message=message, + ) elif isinstance(f1, dict) and len(f1) and isinstance(f1.values()[0], Container): self.assertIsInstance(f2, dict, message) f1_keys = set(f1.keys()) @@ -108,31 +124,42 @@ def _assert_field_equal(self, self.assertSetEqual(f1_keys, f2_keys, message) for k in f1_keys: with self.subTest(module_name=k): - self.assertContainerEqual(f1[k], f2[k], - ignore_hdmf_attrs=ignore_hdmf_attrs, - ignore_string_to_byte=ignore_string_to_byte, - message=message) + self.assertContainerEqual( + f1[k], + f2[k], + ignore_hdmf_attrs=ignore_hdmf_attrs, + ignore_string_to_byte=ignore_string_to_byte, + message=message, + ) elif isinstance(f1, Container): - self.assertContainerEqual(f1, f2, - ignore_hdmf_attrs=ignore_hdmf_attrs, - ignore_string_to_byte=ignore_string_to_byte, - message=message) + self.assertContainerEqual( + f1, + f2, + ignore_hdmf_attrs=ignore_hdmf_attrs, + ignore_string_to_byte=ignore_string_to_byte, + message=message, + ) elif isinstance(f1, Data): - self._assert_data_equal(f1, f2, - ignore_hdmf_attrs=ignore_hdmf_attrs, - ignore_string_to_byte=ignore_string_to_byte, - message=message) + self._assert_data_equal( + f1, + f2, + ignore_hdmf_attrs=ignore_hdmf_attrs, + ignore_string_to_byte=ignore_string_to_byte, + message=message, + ) elif isinstance(f1, (float, np.floating)): np.testing.assert_allclose(f1, f2, err_msg=message) else: self.assertEqual(f1, f2, message) - def _assert_data_equal(self, - data1, - data2, - ignore_hdmf_attrs=False, 
- ignore_string_to_byte=False, - message=None): + def _assert_data_equal( + self, + data1, + data2, + ignore_hdmf_attrs=False, + ignore_string_to_byte=False, + message=None, + ): """ Internal helper function used to compare two :py:class:`~hdmf.container.Data` objects @@ -148,21 +175,28 @@ def _assert_data_equal(self, self.assertTrue(isinstance(data1, Data), message) self.assertTrue(isinstance(data2, Data), message) self.assertEqual(len(data1), len(data2), message) - self._assert_array_equal(data1.data, data2.data, - ignore_hdmf_attrs=ignore_hdmf_attrs, - ignore_string_to_byte=ignore_string_to_byte, - message=message) - self.assertContainerEqual(container1=data1, - container2=data2, - ignore_hdmf_attrs=ignore_hdmf_attrs, - message=message) - - def _assert_array_equal(self, - arr1, - arr2, - ignore_hdmf_attrs=False, - ignore_string_to_byte=False, - message=None): + self._assert_array_equal( + data1.data, + data2.data, + ignore_hdmf_attrs=ignore_hdmf_attrs, + ignore_string_to_byte=ignore_string_to_byte, + message=message, + ) + self.assertContainerEqual( + container1=data1, + container2=data2, + ignore_hdmf_attrs=ignore_hdmf_attrs, + message=message, + ) + + def _assert_array_equal( + self, + arr1, + arr2, + ignore_hdmf_attrs=False, + ignore_string_to_byte=False, + message=None, + ): """ Internal helper function used to check whether two arrays are equal @@ -173,8 +207,9 @@ def _assert_array_equal(self, :param ignore_string_to_byte: ignore conversion of str to bytes and compare as unicode instead :param message: custom additional message to show when assertions as part of this assert are failing """ - array_data_types = tuple([i for i in get_docval_macro('array_data') - if (i != list and i != tuple and i != AbstractDataChunkIterator)]) + array_data_types = tuple( + [i for i in get_docval_macro("array_data") if (i != list and i != tuple and i != AbstractDataChunkIterator)] + ) # We construct array_data_types this way to avoid explicit dependency on h5py, Zarr and other # I/O backends. Only list and tuple do not support [()] slicing, and AbstractDataChunkIterator # should never occur here. 
The effective value of array_data_types is then: @@ -189,9 +224,9 @@ def _assert_array_equal(self, else: if ignore_string_to_byte: if isinstance(arr1, bytes): - arr1 = arr1.decode('utf-8') + arr1 = arr1.decode("utf-8") if isinstance(arr2, bytes): - arr2 = arr2.decode('utf-8') + arr2 = arr2.decode("utf-8") self.assertEqual(arr1, arr2, message) # scalar else: self.assertEqual(len(arr1), len(arr2), message) @@ -207,27 +242,38 @@ def _assert_array_equal(self, else: for sub1, sub2 in zip(arr1, arr2): if isinstance(sub1, Container): - self.assertContainerEqual(sub1, sub2, - ignore_hdmf_attrs=ignore_hdmf_attrs, - ignore_string_to_byte=ignore_string_to_byte, - message=message) + self.assertContainerEqual( + sub1, + sub2, + ignore_hdmf_attrs=ignore_hdmf_attrs, + ignore_string_to_byte=ignore_string_to_byte, + message=message, + ) elif isinstance(sub1, Data): - self._assert_data_equal(sub1, sub2, - ignore_hdmf_attrs=ignore_hdmf_attrs, - ignore_string_to_byte=ignore_string_to_byte, - message=message) + self._assert_data_equal( + sub1, + sub2, + ignore_hdmf_attrs=ignore_hdmf_attrs, + ignore_string_to_byte=ignore_string_to_byte, + message=message, + ) else: - self._assert_array_equal(sub1, sub2, - ignore_hdmf_attrs=ignore_hdmf_attrs, - ignore_string_to_byte=ignore_string_to_byte, - message=message) - - def assertBuilderEqual(self, - builder1, - builder2, - check_path=True, - check_source=True, - message=None): + self._assert_array_equal( + sub1, + sub2, + ignore_hdmf_attrs=ignore_hdmf_attrs, + ignore_string_to_byte=ignore_string_to_byte, + message=message, + ) + + def assertBuilderEqual( + self, + builder1, + builder2, + check_path=True, + check_source=True, + message=None, + ): """ Test whether two builders are equal. Like assertDictEqual but also checks type, name, path, and source. 
@@ -274,8 +320,8 @@ def setUp(self): self.__manager = get_manager() self.container = self.setUpContainer() self.container_type = self.container.__class__.__name__ - self.filename = 'test_%s.h5' % self.container_type - self.export_filename = 'test_export_%s.h5' % self.container_type + self.filename = "test_%s.h5" % self.container_type + self.export_filename = "test_export_%s.h5" % self.container_type self.writer = None self.reader = None self.export_reader = None @@ -294,7 +340,7 @@ def tearDown(self): @abstractmethod def setUpContainer(self): """Return the Container to read/write.""" - raise NotImplementedError('Cannot run test unless setUpContainer is implemented') + raise NotImplementedError("Cannot run test unless setUpContainer is implemented") def test_roundtrip(self): """Test whether the container read from a written file is the same as the original file.""" @@ -316,16 +362,21 @@ def _test_roundtrip(self, read_container, export=False): if not export: self.assertContainerEqual(read_container, self.container, ignore_name=True) else: - self.assertContainerEqual(read_container, self.container, ignore_name=True, ignore_hdmf_attrs=True) + self.assertContainerEqual( + read_container, + self.container, + ignore_name=True, + ignore_hdmf_attrs=True, + ) self.validate(read_container._experimental) def roundtripContainer(self, cache_spec=False): """Write the container to an HDF5 file, read the container from the file, and return it.""" - with HDF5IO(self.filename, manager=get_manager(), mode='w') as write_io: + with HDF5IO(self.filename, manager=get_manager(), mode="w") as write_io: write_io.write(self.container, cache_spec=cache_spec) - self.reader = HDF5IO(self.filename, manager=get_manager(), mode='r') + self.reader = HDF5IO(self.filename, manager=get_manager(), mode="r") return self.reader.read() def roundtripExportContainer(self, cache_spec=False): @@ -338,20 +389,20 @@ def roundtripExportContainer(self, cache_spec=False): cache_spec=cache_spec, ) - self.export_reader = HDF5IO(self.export_filename, manager=get_manager(), mode='r') + self.export_reader = HDF5IO(self.export_filename, manager=get_manager(), mode="r") return self.export_reader.read() def validate(self, experimental=False): """Validate the written and exported files, if they exist.""" if os.path.exists(self.filename): - with HDF5IO(self.filename, manager=get_manager(), mode='r') as io: + with HDF5IO(self.filename, manager=get_manager(), mode="r") as io: errors = common_validate(io, experimental=experimental) if errors: for err in errors: raise Exception(err) if os.path.exists(self.export_filename): - with HDF5IO(self.filename, manager=get_manager(), mode='r') as io: + with HDF5IO(self.filename, manager=get_manager(), mode="r") as io: errors = common_validate(io, experimental=experimental) if errors: for err in errors: diff --git a/src/hdmf/testing/utils.py b/src/hdmf/testing/utils.py index e33f1c354..95649b153 100644 --- a/src/hdmf/testing/utils.py +++ b/src/hdmf/testing/utils.py @@ -7,6 +7,14 @@ def remove_test_file(path): This checks if the environment variable CLEAN_HDMF has been set to False before removing the file. If CLEAN_HDMF is set to False, it does not remove the file. 
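# Editorial aside, not part of the diff: sketch of a concrete roundtrip test built
# on the setUp/setUpContainer/roundtripContainer helpers shown in the hunks above,
# which are assumed to belong to hdmf.testing.H5RoundTripMixin. Table and column
# names are hypothetical.
from hdmf.common import DynamicTable, VectorData
from hdmf.testing import H5RoundTripMixin, TestCase


class TestDynamicTableRoundTrip(H5RoundTripMixin, TestCase):
    def setUpContainer(self):
        # the returned container is written to test_DynamicTable.h5, read back,
        # compared with assertContainerEqual, and validated by the mixin
        return DynamicTable(
            name="test_table",
            description="a table written and read back by the mixin",
            columns=[VectorData(name="col1", description="an example column", data=[1, 2, 3])],
        )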
""" - clean_flag_set = os.getenv('CLEAN_HDMF', True) not in ('False', 'false', 'FALSE', '0', 0, False) + false_options = ( + "False", + "false", + "FALSE", + "0", + 0, + False, + ) + clean_flag_set = os.getenv("CLEAN_HDMF", True) not in false_options if os.path.exists(path) and clean_flag_set: os.remove(path) diff --git a/src/hdmf/testing/validate_spec.py b/src/hdmf/testing/validate_spec.py index 89b8704bf..c3b4be9f6 100755 --- a/src/hdmf/testing/validate_spec.py +++ b/src/hdmf/testing/validate_spec.py @@ -17,21 +17,19 @@ def validate_spec(fpath_spec, fpath_schema): :param fpath_schema: path-like """ - schemaAbs = 'file://' + os.path.abspath(fpath_schema) + schemaAbs = "file://" + os.path.abspath(fpath_schema) - f_schema = open(fpath_schema, 'r') + f_schema = open(fpath_schema, "r") schema = json.load(f_schema) class FixResolver(jsonschema.RefResolver): def __init__(self): - jsonschema.RefResolver.__init__(self, - base_uri=schemaAbs, - referrer=None) + jsonschema.RefResolver.__init__(self, base_uri=schemaAbs, referrer=None) self.store[schemaAbs] = schema new_resolver = FixResolver() - f_nwb = open(fpath_spec, 'r') + f_nwb = open(fpath_spec, "r") instance = yaml.safe_load(f_nwb) jsonschema.validate(instance, schema, resolver=new_resolver) @@ -39,19 +37,23 @@ def __init__(self): def main(): parser = ArgumentParser(description="Validate an HDMF/NWB specification") - parser.add_argument("paths", type=str, nargs='+', help="yaml file paths") - parser.add_argument("-m", "--metaschema", type=str, - help=".json.schema file used to validate yaml files") + parser.add_argument("paths", type=str, nargs="+", help="yaml file paths") + parser.add_argument( + "-m", + "--metaschema", + type=str, + help=".json.schema file used to validate yaml files", + ) args = parser.parse_args() for path in args.paths: if os.path.isfile(path): validate_spec(path, args.metaschema) elif os.path.isdir(path): - for ipath in glob(os.path.join(path, '*.yaml')): + for ipath in glob(os.path.join(path, "*.yaml")): validate_spec(ipath, args.metaschema) else: - raise ValueError('path must be a valid file or directory') + raise ValueError("path must be a valid file or directory") if __name__ == "__main__": diff --git a/src/hdmf/utils.py b/src/hdmf/utils.py index 90b52b706..3634fa6b2 100644 --- a/src/hdmf/utils.py +++ b/src/hdmf/utils.py @@ -8,23 +8,23 @@ import h5py import numpy as np - __macros = { - 'array_data': [np.ndarray, list, tuple, h5py.Dataset], - 'scalar_data': [str, int, float, bytes, bool], - 'data': [] + "array_data": [np.ndarray, list, tuple, h5py.Dataset], + "scalar_data": [str, int, float, bytes, bool], + "data": [], } try: # optionally accept zarr.Array as array data to support conversion of data from Zarr to HDMF import zarr - __macros['array_data'].append(zarr.Array) + + __macros["array_data"].append(zarr.Array) except ImportError: pass # code to signify how to handle positional arguments in docval -AllowPositional = Enum('AllowPositional', 'ALLOWED WARNING ERROR') +AllowPositional = Enum("AllowPositional", "ALLOWED WARNING ERROR") __supported_bool_types = (bool, np.bool_) __supported_uint_types = (np.uint8, np.uint16, np.uint32, np.uint64) @@ -37,13 +37,13 @@ # non-deterministically. a future version of h5py will fix this. 
see #112 __supported_float_types.append(np.longdouble) __supported_float_types = tuple(__supported_float_types) -__allowed_enum_types = (__supported_bool_types + __supported_uint_types + __supported_int_types - + __supported_float_types + (str,)) +__allowed_enum_types = ( + __supported_bool_types + __supported_uint_types + __supported_int_types + __supported_float_types + (str,) +) def docval_macro(macro): - """Class decorator to add the class to a list of types associated with the key macro in the __macros dict - """ + """Class decorator to add the class to a list of types associated with the key macro in the __macros dict""" def _dec(cls): if macro not in __macros: @@ -70,31 +70,31 @@ def get_docval_macro(key=None): def __type_okay(value, argtype, allow_none=False): """Check a value against a type - The difference between this function and :py:func:`isinstance` is that - it allows specifying a type as a string. Furthermore, strings allow for specifying more general - types, such as a simple numeric type (i.e. ``argtype``="num"). + The difference between this function and :py:func:`isinstance` is that + it allows specifying a type as a string. Furthermore, strings allow for specifying more general + types, such as a simple numeric type (i.e. ``argtype``="num"). - Args: - value (any): the value to check - argtype (type, str): the type to check for - allow_none (bool): whether or not to allow None as a valid value + Args: + value (any): the value to check + argtype (type, str): the type to check for + allow_none (bool): whether or not to allow None as a valid value - Returns: - bool: True if value is a valid instance of argtype + Returns: + bool: True if value is a valid instance of argtype """ if value is None: return allow_none if isinstance(argtype, str): if argtype in __macros: return __type_okay(value, __macros[argtype], allow_none=allow_none) - elif argtype == 'uint': + elif argtype == "uint": return __is_uint(value) - elif argtype == 'int': + elif argtype == "int": return __is_int(value) - elif argtype == 'float': + elif argtype == "float": return __is_float(value) - elif argtype == 'bool': + elif argtype == "bool": return __is_bool(value) return argtype in [cls.__name__ for cls in value.__class__.__mro__] elif isinstance(argtype, type): @@ -170,22 +170,30 @@ def __check_enum(argval, arg): :return: None if the value validates successfully, error message if the value does not. 
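# Editorial aside, not part of the diff: the 'array_data' docval macro configured
# above can be inspected and extended as sketched here. MyArrayLike is a
# hypothetical example type.
from hdmf.utils import docval_macro, get_docval_macro

array_types = get_docval_macro("array_data")  # (ndarray, list, tuple, h5py.Dataset, ...)


@docval_macro("array_data")
class MyArrayLike:
    """After registration, docval arguments typed 'array_data' also accept this class."""

    def __init__(self, data):
        self.data = data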
""" - if argval not in arg['enum']: - return "forbidden value for '{}' (got {}, expected {})".format(arg['name'], __fmt_str_quotes(argval), - arg['enum']) + if argval not in arg["enum"]: + return "forbidden value for '{}' (got {}, expected {})".format( + arg["name"], __fmt_str_quotes(argval), arg["enum"] + ) def __fmt_str_quotes(x): """Return a string or list of strings where the input string or list of strings have single quotes around strings""" if isinstance(x, (list, tuple)): - return '{}'.format(x) + return "{}".format(x) if isinstance(x, str): return "'%s'" % x return str(x) -def __parse_args(validator, args, kwargs, enforce_type=True, enforce_shape=True, allow_extra=False, # noqa: C901 - allow_positional=AllowPositional.ALLOWED): +def __parse_args( # noqa: C901 + validator, + args, + kwargs, + enforce_type=True, + enforce_shape=True, + allow_extra=False, + allow_positional=AllowPositional.ALLOWED, +): """ Internal helper function used by the docval decorator to parse and validate function arguments @@ -217,32 +225,33 @@ def __parse_args(validator, args, kwargs, enforce_type=True, enforce_shape=True, try: # check for duplicates in docval - names = [x['name'] for x in validator] - duplicated = [item for item, count in collections.Counter(names).items() - if count > 1] + names = [x["name"] for x in validator] + duplicated = [item for item, count in collections.Counter(names).items() if count > 1] if duplicated: - raise ValueError( - 'The following names are duplicated: {}'.format(duplicated)) + raise ValueError("The following names are duplicated: {}".format(duplicated)) if allow_extra: # extra keyword arguments are allowed so do not consider them when checking number of args if len(args) > len(validator): raise TypeError( - 'Expected at most %d arguments %r, got %d positional' % (len(validator), names, len(args)) + "Expected at most %d arguments %r, got %d positional" % (len(validator), names, len(args)) ) else: # allow for keyword args if len(args) + len(kwargs) > len(validator): raise TypeError( - 'Expected at most %d arguments %r, got %d: %d positional and %d keyword %s' + "Expected at most %d arguments %r, got %d: %d positional and %d keyword %s" % (len(validator), names, len(args) + len(kwargs), len(args), len(kwargs), sorted(kwargs)) ) if args: if allow_positional == AllowPositional.WARNING: - msg = ('Using positional arguments for this method is discouraged and will be deprecated ' - 'in a future major release. Please use keyword arguments to ensure future compatibility.') + msg = ( + "Using positional arguments for this method is discouraged and will" + " be deprecated in a future major release. Please use keyword" + " arguments to ensure future compatibility." + ) future_warnings.append(msg) elif allow_positional == AllowPositional.ERROR: - msg = 'Only keyword arguments (e.g., func(argname=value, ...)) are allowed for this method.' + msg = "Only keyword arguments (e.g., func(argname=value, ...)) are allowed for this method." 
syntax_errors.append(msg) # iterate through the docval specification and find a matching value in args / kwargs @@ -252,9 +261,9 @@ def __parse_args(validator, args, kwargs, enforce_type=True, enforce_shape=True, # process positional arguments of the docval specification (no default value) extras = dict(kwargs) while True: - if 'default' in arg: + if "default" in arg: break - argname = arg['name'] + argname = arg["name"] argval_set = False if argname in kwargs: # if this positional arg is specified by a keyword arg and there are remaining positional args that @@ -273,30 +282,35 @@ def __parse_args(validator, args, kwargs, enforce_type=True, enforce_shape=True, type_errors.append("missing argument '%s'" % argname) else: if enforce_type: - if not __type_okay(argval, arg['type']): + if not __type_okay(argval, arg["type"]): if argval is None: - fmt_val = (argname, __format_type(arg['type'])) + fmt_val = (argname, __format_type(arg["type"])) type_errors.append("None is not allowed for '%s' (expected '%s', not None)" % fmt_val) else: - fmt_val = (argname, type(argval).__name__, __format_type(arg['type'])) + fmt_val = ( + argname, + type(argval).__name__, + __format_type(arg["type"]), + ) type_errors.append("incorrect type for '%s' (got '%s', expected '%s')" % fmt_val) - if enforce_shape and 'shape' in arg: + if enforce_shape and "shape" in arg: valshape = get_data_shape(argval) while valshape is None: if argval is None: break if not hasattr(argval, argname): - fmt_val = (argval, argname, arg['shape']) - value_errors.append("cannot check shape of object '%s' for argument '%s' " - "(expected shape '%s')" % fmt_val) + fmt_val = (argval, argname, arg["shape"]) + value_errors.append( + "cannot check shape of object '%s' for argument '%s' (expected shape '%s')" % fmt_val + ) break # unpack, e.g. 
if TimeSeries is passed for arg 'data', then TimeSeries.data is checked argval = getattr(argval, argname) valshape = get_data_shape(argval) - if valshape is not None and not __shape_okay_multi(argval, arg['shape']): - fmt_val = (argname, valshape, arg['shape']) + if valshape is not None and not __shape_okay_multi(argval, arg["shape"]): + fmt_val = (argname, valshape, arg["shape"]) value_errors.append("incorrect shape for '%s' (got '%s', expected '%s')" % fmt_val) - if 'enum' in arg: + if "enum" in arg: err = __check_enum(argval, arg) if err: value_errors.append(err) @@ -308,7 +322,7 @@ def __parse_args(validator, args, kwargs, enforce_type=True, enforce_shape=True, # process arguments of the docval specification with a default value # NOTE: the default value will be deepcopied, so 'default': list() is safe unlike in normal python while True: - argname = arg['name'] + argname = arg["name"] if argname in kwargs: ret[argname] = kwargs.get(argname) extras.pop(argname, None) @@ -316,33 +330,42 @@ def __parse_args(validator, args, kwargs, enforce_type=True, enforce_shape=True, ret[argname] = args[argsi] argsi += 1 else: - ret[argname] = _copy.deepcopy(arg['default']) + ret[argname] = _copy.deepcopy(arg["default"]) argval = ret[argname] if enforce_type: - if not __type_okay(argval, arg['type'], arg['default'] is None or arg.get('allow_none', False)): - if argval is None and arg['default'] is None: - fmt_val = (argname, __format_type(arg['type'])) + if not __type_okay( + argval, + arg["type"], + arg["default"] is None or arg.get("allow_none", False), + ): + if argval is None and arg["default"] is None: + fmt_val = (argname, __format_type(arg["type"])) type_errors.append("None is not allowed for '%s' (expected '%s', not None)" % fmt_val) else: - fmt_val = (argname, type(argval).__name__, __format_type(arg['type'])) + fmt_val = ( + argname, + type(argval).__name__, + __format_type(arg["type"]), + ) type_errors.append("incorrect type for '%s' (got '%s', expected '%s')" % fmt_val) - if enforce_shape and 'shape' in arg and argval is not None: + if enforce_shape and "shape" in arg and argval is not None: valshape = get_data_shape(argval) while valshape is None: if argval is None: break if not hasattr(argval, argname): - fmt_val = (argval, argname, arg['shape']) - value_errors.append("cannot check shape of object '%s' for argument '%s' (expected shape '%s')" - % fmt_val) + fmt_val = (argval, argname, arg["shape"]) + value_errors.append( + "cannot check shape of object '%s' for argument '%s' (expected shape '%s')" % fmt_val + ) break # unpack, e.g. 
if TimeSeries is passed for arg 'data', then TimeSeries.data is checked argval = getattr(argval, argname) valshape = get_data_shape(argval) - if valshape is not None and not __shape_okay_multi(argval, arg['shape']): - fmt_val = (argname, valshape, arg['shape']) + if valshape is not None and not __shape_okay_multi(argval, arg["shape"]): + fmt_val = (argname, valshape, arg["shape"]) value_errors.append("incorrect shape for '%s' (got '%s', expected '%s')" % fmt_val) - if 'enum' in arg and argval is not None: + if "enum" in arg and argval is not None: err = __check_enum(argval, arg) if err: value_errors.append(err) @@ -363,19 +386,24 @@ def __parse_args(validator, args, kwargs, enforce_type=True, enforce_shape=True, # allow_extra needs to be tracked on a function so that fmt_docval_args doesn't strip them out for key in extras.keys(): ret[key] = extras[key] - return {'args': ret, 'future_warnings': future_warnings, 'type_errors': type_errors, 'value_errors': value_errors, - 'syntax_errors': syntax_errors} + return { + "args": ret, + "future_warnings": future_warnings, + "type_errors": type_errors, + "value_errors": value_errors, + "syntax_errors": syntax_errors, + } -docval_idx_name = '__dv_idx__' -docval_attr_name = '__docval__' -__docval_args_loc = 'args' +docval_idx_name = "__dv_idx__" +docval_attr_name = "__docval__" +__docval_args_loc = "args" def get_docval(func, *args): - '''Get a copy of docval arguments for a function. + """Get a copy of docval arguments for a function. If args are supplied, return only docval arguments with value for 'name' key equal to the args - ''' + """ func_docval = getattr(func, docval_attr_name, None) if func_docval: if args: @@ -383,11 +411,11 @@ def get_docval(func, *args): try: return tuple(docval_idx[name] for name in args) except KeyError as ke: - raise ValueError('Function %s does not have docval argument %s' % (func.__name__, str(ke))) + raise ValueError("Function %s does not have docval argument %s" % (func.__name__, str(ke))) return tuple(func_docval[__docval_args_loc]) else: if args: - raise ValueError('Function %s has no docval arguments' % func.__name__) + raise ValueError("Function %s has no docval arguments" % func.__name__) return tuple() @@ -406,32 +434,38 @@ def get_docval(func, *args): def fmt_docval_args(func, kwargs): - ''' Separate positional and keyword arguments + """Separate positional and keyword arguments Useful for methods that wrap other methods - ''' - warnings.warn("fmt_docval_args will be deprecated in a future version of HDMF. Instead of using fmt_docval_args, " - "call the function directly with the kwargs. Please note that fmt_docval_args " - "removes all arguments not accepted by the function's docval, so if you are passing kwargs that " - "includes extra arguments and the function's docval does not allow extra arguments (allow_extra=True " - "is set), then you will need to pop the extra arguments out of kwargs before calling the function.", - PendingDeprecationWarning) + """ + warnings.warn( + ( + "fmt_docval_args will be deprecated in a future version of HDMF. Instead of" + " using fmt_docval_args, call the function directly with the kwargs. Please" + " note that fmt_docval_args removes all arguments not accepted by the" + " function's docval, so if you are passing kwargs that includes extra" + " arguments and the function's docval does not allow extra arguments" + " (allow_extra=True is set), then you will need to pop the extra arguments" + " out of kwargs before calling the function." 
+ ), + PendingDeprecationWarning, + ) func_docval = getattr(func, docval_attr_name, None) ret_args = list() ret_kwargs = dict() kwargs_copy = _copy.copy(kwargs) if func_docval: for arg in func_docval[__docval_args_loc]: - val = kwargs_copy.pop(arg['name'], None) - if 'default' in arg: + val = kwargs_copy.pop(arg["name"], None) + if "default" in arg: if val is not None: - ret_kwargs[arg['name']] = val + ret_kwargs[arg["name"]] = val else: ret_args.append(val) - if func_docval['allow_extra']: + if func_docval["allow_extra"]: ret_kwargs.update(kwargs_copy) else: - raise ValueError('no docval found on %s' % str(func)) + raise ValueError("no docval found on %s" % str(func)) return ret_args, ret_kwargs @@ -464,12 +498,18 @@ def call_docval_func(func, kwargs): Extra keyword arguments are not passed to the function unless the function's docval has allow_extra=True. """ - warnings.warn("call_docval_func will be deprecated in a future version of HDMF. Instead of using call_docval_func, " - "call the function directly with the kwargs. Please note that call_docval_func " - "removes all arguments not accepted by the function's docval, so if you are passing kwargs that " - "includes extra arguments and the function's docval does not allow extra arguments (allow_extra=True " - "is set), then you will need to pop the extra arguments out of kwargs before calling the function.", - PendingDeprecationWarning) + warnings.warn( + ( + "call_docval_func will be deprecated in a future version of HDMF. Instead" + " of using call_docval_func, call the function directly with the kwargs." + " Please note that call_docval_func removes all arguments not accepted by" + " the function's docval, so if you are passing kwargs that includes extra" + " arguments and the function's docval does not allow extra arguments" + " (allow_extra=True is set), then you will need to pop the extra arguments" + " out of kwargs before calling the function." + ), + PendingDeprecationWarning, + ) with warnings.catch_warnings(record=True): # catch and ignore only PendingDeprecationWarnings from fmt_docval_args so that two # PendingDeprecationWarnings saying the same thing are not raised @@ -512,7 +552,7 @@ def __check_enum_argtype(argtype): def docval(*validator, **options): # noqa: C901 - '''A decorator for documenting and enforcing type for instance method arguments. + """A decorator for documenting and enforcing type for instance method arguments. This decorator takes a list of dictionaries that specify the method parameters. These dictionaries are used for enforcing type and building a Sphinx docstring. 
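# Editorial aside, not part of the diff: a docval specification exercising the
# 'type', 'shape', 'enum', and 'default' keys whose validation is reformatted in
# the surrounding hunks. normalize() is a hypothetical module-level function.
import numpy as np

from hdmf.utils import docval, get_docval, getargs


@docval(
    {"name": "data", "type": "array_data", "doc": "a 1-D array", "shape": (None,)},
    {"name": "unit", "type": str, "doc": "unit of measurement", "enum": ["m", "s", "kg"]},
    {"name": "scale", "type": (int, float), "doc": "scale factor", "default": 1.0},
    is_method=False,
)
def normalize(**kwargs):
    data, unit, scale = getargs("data", "unit", "scale", kwargs)
    return np.asarray(data) * scale, unit


values, unit = normalize(data=[1, 2, 3], unit="m", scale=2.0)
unit_spec = get_docval(normalize, "unit")[0]  # the docval dict describing 'unit'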
@@ -553,56 +593,72 @@ def foo(self, **kwargs): :param allow_extra: Allow extra arguments (Default=False) :param allow_positional: Allow positional arguments (Default=True) :param options: additional options for documenting and validating method parameters - ''' - enforce_type = options.pop('enforce_type', True) - enforce_shape = options.pop('enforce_shape', True) - returns = options.pop('returns', None) - rtype = options.pop('rtype', None) - is_method = options.pop('is_method', True) - allow_extra = options.pop('allow_extra', False) - allow_positional = options.pop('allow_positional', True) + """ + enforce_type = options.pop("enforce_type", True) + enforce_shape = options.pop("enforce_shape", True) + returns = options.pop("returns", None) + rtype = options.pop("rtype", None) + is_method = options.pop("is_method", True) + allow_extra = options.pop("allow_extra", False) + allow_positional = options.pop("allow_positional", True) def dec(func): _docval = _copy.copy(options) - _docval['allow_extra'] = allow_extra - _docval['allow_positional'] = allow_positional - func.__name__ = _docval.get('func_name', func.__name__) - func.__doc__ = _docval.get('doc', func.__doc__) + _docval["allow_extra"] = allow_extra + _docval["allow_positional"] = allow_positional + func.__name__ = _docval.get("func_name", func.__name__) + func.__doc__ = _docval.get("doc", func.__doc__) pos = list() kw = list() for a in validator: # catch unsupported keys - allowable_terms = ('name', 'doc', 'type', 'shape', 'enum', 'default', 'allow_none', 'help') + allowable_terms = ( + "name", + "doc", + "type", + "shape", + "enum", + "default", + "allow_none", + "help", + ) unsupported_terms = set(a.keys()) - set(allowable_terms) if unsupported_terms: - raise Exception('docval for {}: keys {} are not supported by docval'.format(a['name'], - sorted(unsupported_terms))) + raise Exception( + "docval for {}: keys {} are not supported by docval".format(a["name"], sorted(unsupported_terms)) + ) # check that arg type is valid try: - a['type'] = __resolve_type(a['type']) + a["type"] = __resolve_type(a["type"]) except Exception as e: - msg = "docval for %s: error parsing argument type: %s" % (a['name'], e.args[0]) + msg = "docval for %s: error parsing argument type: %s" % (a["name"], e.args[0]) raise Exception(msg) - if 'enum' in a: + if "enum" in a: # check that value for enum key is a list or tuple (cannot have only one allowed value) - if not isinstance(a['enum'], (list, tuple)): - msg = ('docval for %s: enum value must be a list or tuple (received %s)' - % (a['name'], type(a['enum']))) + if not isinstance(a["enum"], (list, tuple)): + msg = "docval for %s: enum value must be a list or tuple (received %s)" % ( + a["name"], + type(a["enum"]), + ) raise Exception(msg) # check that arg type is compatible with enum - if not __check_enum_argtype(a['type']): - msg = 'docval for {}: enum checking cannot be used with arg type {}'.format(a['name'], a['type']) + if not __check_enum_argtype(a["type"]): + msg = "docval for {}: enum checking cannot be used with arg type {}".format(a["name"], a["type"]) raise Exception(msg) # check that enum allowed values are allowed by arg type - if any([not __type_okay(x, a['type']) for x in a['enum']]): - msg = ('docval for {}: enum values are of types not allowed by arg type (got {}, ' - 'expected {})'.format(a['name'], [type(x) for x in a['enum']], a['type'])) + if any([not __type_okay(x, a["type"]) for x in a["enum"]]): + msg = ( + "docval for {}: enum values are of types not allowed by arg type (got {}, expected 
{})".format( + a["name"], [type(x) for x in a["enum"]], a["type"] + ) + ) raise Exception(msg) - if a.get('allow_none', False) and 'default' not in a: - msg = ('docval for {}: allow_none=True can only be set if a default value is provided.').format( - a['name']) + if a.get("allow_none", False) and "default" not in a: + msg = ("docval for {}: allow_none=True can only be set if a default value is provided.").format( + a["name"] + ) raise Exception(msg) - if 'default' in a: + if "default" in a: kw.append(a) else: pos.append(a) @@ -620,30 +676,35 @@ def _check_args(args, kwargs): enforce_type=enforce_type, enforce_shape=enforce_shape, allow_extra=allow_extra, - allow_positional=allow_positional + allow_positional=allow_positional, ) - parse_warnings = parsed.get('future_warnings') + parse_warnings = parsed.get("future_warnings") if parse_warnings: - msg = '%s: %s' % (func.__qualname__, ', '.join(parse_warnings)) + msg = "%s: %s" % (func.__qualname__, ", ".join(parse_warnings)) warnings.warn(msg, FutureWarning) - for error_type, ExceptionType in (('type_errors', TypeError), - ('value_errors', ValueError), - ('syntax_errors', SyntaxError)): + for error_type, ExceptionType in ( + ("type_errors", TypeError), + ("value_errors", ValueError), + ("syntax_errors", SyntaxError), + ): parse_err = parsed.get(error_type) if parse_err: - msg = '%s: %s' % (func.__qualname__, ', '.join(parse_err)) + msg = "%s: %s" % (func.__qualname__, ", ".join(parse_err)) raise ExceptionType(msg) - return parsed['args'] + return parsed["args"] # this code is intentionally separated to make stepping through lines of code using pdb easier if is_method: + def func_call(*args, **kwargs): pargs = _check_args(args, kwargs) return func(args[0], **pargs) + else: + def func_call(*args, **kwargs): pargs = _check_args(args, kwargs) return func(**pargs) @@ -652,31 +713,39 @@ def func_call(*args, **kwargs): if isinstance(rtype, type): _rtype = rtype.__name__ docstring = __googledoc(func, _docval[__docval_args_loc], returns=returns, rtype=_rtype) - docval_idx = {a['name']: a for a in _docval[__docval_args_loc]} # cache a name-indexed dictionary of args - setattr(func_call, '__doc__', docstring) - setattr(func_call, '__name__', func.__name__) + docval_idx = {a["name"]: a for a in _docval[__docval_args_loc]} # cache a name-indexed dictionary of args + setattr(func_call, "__doc__", docstring) + setattr(func_call, "__name__", func.__name__) setattr(func_call, docval_attr_name, _docval) setattr(func_call, docval_idx_name, docval_idx) - setattr(func_call, '__module__', func.__module__) + setattr(func_call, "__module__", func.__module__) return func_call return dec def __sig_arg(argval): - if 'default' in argval: - default = argval['default'] + if "default" in argval: + default = argval["default"] if isinstance(default, str): default = "'%s'" % default else: default = str(default) - return "%s=%s" % (argval['name'], default) + return "%s=%s" % (argval["name"], default) else: - return argval['name'] + return argval["name"] -def __builddoc(func, validator, docstring_fmt, arg_fmt, ret_fmt=None, returns=None, rtype=None): - '''Generate a Spinxy docstring''' +def __builddoc( + func, + validator, + docstring_fmt, + arg_fmt, + ret_fmt=None, + returns=None, + rtype=None, +): + """Generate a Spinxy docstring""" def to_str(argtype): if isinstance(argtype, type): @@ -691,12 +760,12 @@ def to_str(argtype): def __sphinx_arg(arg): fmt = dict() - fmt['name'] = arg.get('name') - fmt['doc'] = arg.get('doc') - if isinstance(arg['type'], tuple) or 
isinstance(arg['type'], list): - fmt['type'] = " or ".join(map(to_str, arg['type'])) + fmt["name"] = arg.get("name") + fmt["doc"] = arg.get("doc") + if isinstance(arg["type"], tuple) or isinstance(arg["type"], list): + fmt["type"] = " or ".join(map(to_str, arg["type"])) else: - fmt['type'] = to_str(arg['type']) + fmt["type"] = to_str(arg["type"]) return arg_fmt.format(**fmt) sig = "%s(%s)\n\n" % (func.__name__, ", ".join(map(__sig_arg, validator))) @@ -709,13 +778,18 @@ def __sphinx_arg(arg): def __sphinxdoc(func, validator, returns=None, rtype=None): - arg_fmt = (":param {name}: {doc}\n" - ":type {name}: {type}") - docstring_fmt = ("{description}\n\n" - "{args}\n") - ret_fmt = (":returns: {returns}\n" - ":rtype: {rtype}") - return __builddoc(func, validator, docstring_fmt, arg_fmt, ret_fmt=ret_fmt, returns=returns, rtype=rtype) + arg_fmt = ":param {name}: {doc}\n:type {name}: {type}" + docstring_fmt = "{description}\n\n{args}\n" + ret_fmt = ":returns: {returns}\n:rtype: {rtype}" + return __builddoc( + func, + validator, + docstring_fmt, + arg_fmt, + ret_fmt=ret_fmt, + returns=returns, + rtype=rtype, + ) def __googledoc(func, validator, returns=None, rtype=None): @@ -723,9 +797,16 @@ def __googledoc(func, validator, returns=None, rtype=None): docstring_fmt = "{description}\n\n" if len(validator) > 0: docstring_fmt += "Args:\n{args}\n" - ret_fmt = ("\nReturns:\n" - " {rtype}: {returns}") - return __builddoc(func, validator, docstring_fmt, arg_fmt, ret_fmt=ret_fmt, returns=returns, rtype=rtype) + ret_fmt = "\nReturns:\n {rtype}: {returns}" + return __builddoc( + func, + validator, + docstring_fmt, + arg_fmt, + ret_fmt=ret_fmt, + returns=returns, + rtype=rtype, + ) def getargs(*argnames): @@ -740,9 +821,9 @@ def getargs(*argnames): :return: a single value if there is only one argument, or a list of values corresponding to the given argument names """ if len(argnames) < 2: - raise ValueError('Must supply at least one key and a dict') + raise ValueError("Must supply at least one key and a dict") if not isinstance(argnames[-1], dict): - raise ValueError('Last argument must be a dict') + raise ValueError("Last argument must be a dict") kwargs = argnames[-1] if len(argnames) == 2: if argnames[0] not in kwargs: @@ -768,20 +849,20 @@ def popargs(*argnames): :return: a single value if there is only one argument, or a list of values corresponding to the given argument names """ if len(argnames) < 2: - raise ValueError('Must supply at least one key and a dict') + raise ValueError("Must supply at least one key and a dict") if not isinstance(argnames[-1], dict): - raise ValueError('Last argument must be a dict') + raise ValueError("Last argument must be a dict") kwargs = argnames[-1] if len(argnames) == 2: try: ret = kwargs.pop(argnames[0]) except KeyError as ke: - raise ValueError('Argument not found in dict: %s' % str(ke)) + raise ValueError("Argument not found in dict: %s" % str(ke)) return ret try: ret = [kwargs.pop(arg) for arg in argnames[:-1]] except KeyError as ke: - raise ValueError('Argument not found in dict: %s' % str(ke)) + raise ValueError("Argument not found in dict: %s" % str(ke)) return ret @@ -802,35 +883,35 @@ def popargs_to_dict(keys, argdict): try: ret[arg] = argdict.pop(arg) except KeyError as ke: - raise ValueError('Argument not found in dict: %s' % str(ke)) + raise ValueError("Argument not found in dict: %s" % str(ke)) return ret class ExtenderMeta(ABCMeta): """A metaclass that will extend the base class initialization - routine by executing additional functions defined in - 
classes that use this metaclass + routine by executing additional functions defined in + classes that use this metaclass - In general, this class should only be used by core developers. + In general, this class should only be used by core developers. """ - __preinit = '__preinit' + __preinit = "__preinit" @classmethod def pre_init(cls, func): setattr(func, cls.__preinit, True) return classmethod(func) - __postinit = '__postinit' + __postinit = "__postinit" @classmethod def post_init(cls, func): - '''A decorator for defining a routine to run after creation of a type object. + """A decorator for defining a routine to run after creation of a type object. An example use of this method would be to define a classmethod that gathers any defined methods or attributes after the base Python type construction (i.e. after :py:func:`type` has been called) - ''' + """ setattr(func, cls.__postinit, True) return classmethod(func) @@ -867,7 +948,7 @@ def get_data_shape(data, strict_no_data_load=False): def __get_shape_helper(local_data): shape = list() - if hasattr(local_data, '__len__'): + if hasattr(local_data, "__len__"): shape.append(len(local_data)) if len(local_data): el = next(iter(local_data)) @@ -876,13 +957,13 @@ def __get_shape_helper(local_data): return tuple(shape) # NOTE: data.maxshape will fail on empty h5py.Dataset without shape or maxshape. this will be fixed in h5py 3.0 - if hasattr(data, 'maxshape'): + if hasattr(data, "maxshape"): return data.maxshape - if hasattr(data, 'shape') and data.shape is not None: + if hasattr(data, "shape") and data.shape is not None: return data.shape if isinstance(data, dict): return None - if hasattr(data, '__len__') and not isinstance(data, (str, bytes)): + if hasattr(data, "__len__") and not isinstance(data, (str, bytes)): if not strict_no_data_load or isinstance(data, (list, tuple, set)): return __get_shape_helper(data) return None @@ -893,7 +974,7 @@ def pystr(s): Convert a string of characters to Python str object """ if isinstance(s, bytes): - return s.decode('utf-8') + return s.decode("utf-8") else: return s @@ -911,10 +992,10 @@ def to_uint_array(arr): return arr if np.issubdtype(arr.dtype, np.integer): if (arr < 0).any(): - raise ValueError('Cannot convert negative integer values to uint.') - dt = np.dtype('uint' + str(int(arr.dtype.itemsize*8))) # keep precision + raise ValueError("Cannot convert negative integer values to uint.") + dt = np.dtype("uint" + str(int(arr.dtype.itemsize * 8))) # keep precision return arr.astype(dt) - raise ValueError('Cannot convert array of dtype %s to uint.' % arr.dtype) + raise ValueError("Cannot convert array of dtype %s to uint." 
% arr.dtype) class LabelledDict(dict): @@ -959,22 +1040,39 @@ class LabelledDict(dict): ld['prop2 == b'] # Returns set([obj1, obj2]) - the set of all values v in ld where v.prop2 == 'b' """ - @docval({'name': 'label', 'type': str, 'doc': 'the label on this dictionary'}, - {'name': 'key_attr', 'type': str, 'doc': 'the attribute name to use as the key', 'default': 'name'}, - {'name': 'add_callable', 'type': types.FunctionType, - 'doc': 'function to call on an element after adding it to this dict using the add or __setitem__ methods', - 'default': None}, - {'name': 'remove_callable', 'type': types.FunctionType, - 'doc': ('function to call on an element after removing it from this dict using the pop, popitem, clear, ' - 'or __delitem__ methods'), - 'default': None}) + @docval( + { + "name": "label", + "type": str, + "doc": "the label on this dictionary", + }, + { + "name": "key_attr", + "type": str, + "doc": "the attribute name to use as the key", + "default": "name", + }, + { + "name": "add_callable", + "type": types.FunctionType, + "doc": "function to call on an element after adding it to this dict using the add or __setitem__ methods", + "default": None, + }, + { + "name": "remove_callable", + "type": types.FunctionType, + "doc": ( + "function to call on an element after removing it from this dict using" + " the pop, popitem, clear, or __delitem__ methods" + ), + "default": None, + }, + ) def __init__(self, **kwargs): - label, key_attr, add_callable, remove_callable = getargs('label', 'key_attr', 'add_callable', 'remove_callable', - kwargs) - self.__label = label - self.__key_attr = key_attr - self.__add_callable = add_callable - self.__remove_callable = remove_callable + self.__label = kwargs["label"] + self.__key_attr = kwargs["key_attr"] + self.__add_callable = kwargs["add_callable"] + self.__remove_callable = kwargs["remove_callable"] @property def label(self): @@ -995,7 +1093,7 @@ def __getitem__(self, args): returned, not a set. """ key = args - if '==' in args: + if "==" in args: key, val = args.split("==") key = key.strip() val = val.strip() # val is a string @@ -1022,8 +1120,9 @@ def __setitem__(self, key, value): Raises ValueError if value does not have attribute key_attr. """ if key in self: - raise TypeError("Key '%s' is already in this dict. Cannot reset items in a %s." - % (key, self.__class__.__name__)) + raise TypeError( + "Key '%s' is already in this dict. Cannot reset items in a %s." % (key, self.__class__.__name__) + ) self.__check_value(value) if key != getattr(value, self.key_attr): raise KeyError("Key '%s' must equal attribute '%s' of '%s'." % (key, self.key_attr, value)) @@ -1041,8 +1140,10 @@ def add(self, value): def __check_value(self, value): if not hasattr(value, self.key_attr): - raise ValueError("Cannot set value '%s' in %s. Value must have attribute '%s'." - % (value, self.__class__.__name__, self.key_attr)) + raise ValueError( + "Cannot set value '%s' in %s. Value must have attribute '%s'." + % (value, self.__class__.__name__, self.key_attr) + ) def pop(self, k): """Remove an item that matches the key. If remove_callable was initialized, call that on the returned value.""" @@ -1080,17 +1181,18 @@ def __delitem__(self, k): def setdefault(self, k): """setdefault is not supported. A TypeError will be raised.""" - raise TypeError('setdefault is not supported for %s' % self.__class__.__name__) + raise TypeError("setdefault is not supported for %s" % self.__class__.__name__) def update(self, other): """update is not supported. 
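# Editorial aside, not part of the diff (the hunk continues below): the
# LabelledDict behavior reformatted above, including the "attr == value" query
# syntax from its docstring. Thing is a hypothetical value type keyed on the
# default 'name' attribute.
from hdmf.utils import LabelledDict


class Thing:
    def __init__(self, name, kind):
        self.name = name
        self.kind = kind


ld = LabelledDict(label="things")
ld.add(Thing("a", "widget"))
ld.add(Thing("b", "gadget"))

first = ld["a"]                 # the Thing whose .name == 'a'
widgets = ld["kind == widget"]  # set of all values whose .kind == 'widget'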
A TypeError will be raised.""" - raise TypeError('update is not supported for %s' % self.__class__.__name__) + raise TypeError("update is not supported for %s" % self.__class__.__name__) -@docval_macro('array_data') +@docval_macro("array_data") class StrDataset(h5py.Dataset): """Wrapper to decode strings on reading the dataset""" - def __init__(self, dset, encoding, errors='strict'): + + def __init__(self, dset, encoding, errors="strict"): self.dset = dset if encoding is None: encoding = h5py.h5t.check_string_dtype(dset.dtype).encoding @@ -1101,7 +1203,7 @@ def __getattr__(self, name): return getattr(self.dset, name) def __repr__(self): - return '' % repr(self.dset)[1:-1] + return "" % repr(self.dset)[1:-1] def __len__(self): return len(self.dset) @@ -1117,6 +1219,7 @@ def __getitem__(self, args): if np.isscalar(bytes_arr): return bytes_arr.decode(self.encoding, self.errors) - return np.array([ - b.decode(self.encoding, self.errors) for b in bytes_arr.flat - ], dtype=object).reshape(bytes_arr.shape) + return np.array( + [b.decode(self.encoding, self.errors) for b in bytes_arr.flat], + dtype=object, + ).reshape(bytes_arr.shape) diff --git a/src/hdmf/validate/__init__.py b/src/hdmf/validate/__init__.py index cb515cea8..902cb3038 100644 --- a/src/hdmf/validate/__init__.py +++ b/src/hdmf/validate/__init__.py @@ -1,3 +1,9 @@ from . import errors from .errors import * # noqa: F403 -from .validator import ValidatorMap, Validator, AttributeValidator, DatasetValidator, GroupValidator +from .validator import ( + AttributeValidator, + DatasetValidator, + GroupValidator, + Validator, + ValidatorMap, +) diff --git a/src/hdmf/validate/errors.py b/src/hdmf/validate/errors.py index fb1bfc1b4..2e982ec6d 100644 --- a/src/hdmf/validate/errors.py +++ b/src/hdmf/validate/errors.py @@ -12,19 +12,20 @@ "MissingDataType", "IllegalLinkError", "IncorrectDataType", - "IncorrectQuantityError" + "IncorrectQuantityError", ] class Error: - - @docval({'name': 'name', 'type': str, 'doc': 'the name of the component that is erroneous'}, - {'name': 'reason', 'type': str, 'doc': 'the reason for the error'}, - {'name': 'location', 'type': str, 'doc': 'the location of the error', 'default': None}) + @docval( + {"name": "name", "type": str, "doc": "the name of the component that is erroneous"}, + {"name": "reason", "type": str, "doc": "the reason for the error"}, + {"name": "location", "type": str, "doc": "the location of the error", "default": None}, + ) def __init__(self, **kwargs): - self.__name = getargs('name', kwargs) - self.__reason = getargs('reason', kwargs) - self.__location = getargs('location', kwargs) + self.__name = getargs("name", kwargs) + self.__reason = getargs("reason", kwargs) + self.__location = getargs("location", kwargs) @property def name(self): @@ -76,7 +77,7 @@ def __equatable_str(self): use the fully-provided name. 
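# Editorial aside, not part of the diff: StrDataset (reformatted above) decodes
# byte strings to str on read. The file and dataset names are hypothetical, and
# the encoding is passed explicitly rather than inferred from the dtype.
import h5py

from hdmf.utils import StrDataset

with h5py.File("example_strings.h5", "w") as f:
    f.create_dataset("labels", data=[b"alpha", b"beta"])

with h5py.File("example_strings.h5", "r") as f:
    labels = StrDataset(f["labels"], encoding="utf-8")
    assert labels[0] == "alpha"                  # str, not bytes
    assert list(labels[:]) == ["alpha", "beta"]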
""" if self.location is not None: - equatable_name = self.name.split('/')[-1] + equatable_name = self.name.split("/")[-1] else: equatable_name = self.name return self.__format_str(equatable_name, self.location, self.reason) @@ -86,45 +87,50 @@ def __eq__(self, other): class DtypeError(Error): - - @docval({'name': 'name', 'type': str, 'doc': 'the name of the component that is erroneous'}, - {'name': 'expected', 'type': (dtype, type, str, list), 'doc': 'the expected dtype'}, - {'name': 'received', 'type': (dtype, type, str, list), 'doc': 'the received dtype'}, - {'name': 'location', 'type': str, 'doc': 'the location of the error', 'default': None}) + @docval( + {"name": "name", "type": str, "doc": "the name of the component that is erroneous"}, + {"name": "expected", "type": (dtype, type, str, list), "doc": "the expected dtype"}, + {"name": "received", "type": (dtype, type, str, list), "doc": "the received dtype"}, + {"name": "location", "type": str, "doc": "the location of the error", "default": None}, + ) def __init__(self, **kwargs): - name = getargs('name', kwargs) - expected = getargs('expected', kwargs) - received = getargs('received', kwargs) + name = getargs("name", kwargs) + expected = getargs("expected", kwargs) + received = getargs("received", kwargs) if isinstance(expected, list): expected = DtypeHelper.simplify_cpd_type(expected) reason = "incorrect type - expected '%s', got '%s'" % (expected, received) - loc = getargs('location', kwargs) + loc = getargs("location", kwargs) super().__init__(name, reason, location=loc) class MissingError(Error): - @docval({'name': 'name', 'type': str, 'doc': 'the name of the component that is erroneous'}, - {'name': 'location', 'type': str, 'doc': 'the location of the error', 'default': None}) + @docval( + {"name": "name", "type": str, "doc": "the name of the component that is erroneous"}, + {"name": "location", "type": str, "doc": "the location of the error", "default": None}, + ) def __init__(self, **kwargs): - name = getargs('name', kwargs) + name = getargs("name", kwargs) reason = "argument missing" - loc = getargs('location', kwargs) + loc = getargs("location", kwargs) super().__init__(name, reason, location=loc) class MissingDataType(Error): - @docval({'name': 'name', 'type': str, 'doc': 'the name of the component that is erroneous'}, - {'name': 'data_type', 'type': str, 'doc': 'the missing data type'}, - {'name': 'location', 'type': str, 'doc': 'the location of the error', 'default': None}, - {'name': 'missing_dt_name', 'type': str, 'doc': 'the name of the missing data type', 'default': None}) + @docval( + {"name": "name", "type": str, "doc": "the name of the component that is erroneous"}, + {"name": "data_type", "type": str, "doc": "the missing data type"}, + {"name": "location", "type": str, "doc": "the location of the error", "default": None}, + {"name": "missing_dt_name", "type": str, "doc": "the name of the missing data type", "default": None}, + ) def __init__(self, **kwargs): - name, data_type, missing_dt_name = getargs('name', 'data_type', 'missing_dt_name', kwargs) + name, data_type, missing_dt_name = getargs("name", "data_type", "missing_dt_name", kwargs) self.__data_type = data_type if missing_dt_name is not None: reason = "missing data type %s (%s)" % (self.__data_type, missing_dt_name) else: reason = "missing data type %s" % self.__data_type - loc = getargs('location', kwargs) + loc = getargs("location", kwargs) super().__init__(name, reason, location=loc) @property @@ -134,45 +140,50 @@ def data_type(self): class 
IncorrectQuantityError(Error): """A validation error indicating that a child group/dataset/link has the incorrect quantity of matching elements""" - @docval({'name': 'name', 'type': str, 'doc': 'the name of the component that is erroneous'}, - {'name': 'data_type', 'type': str, 'doc': 'the data type which has the incorrect quantity'}, - {'name': 'expected', 'type': (str, int), 'doc': 'the expected quantity'}, - {'name': 'received', 'type': (str, int), 'doc': 'the received quantity'}, - {'name': 'location', 'type': str, 'doc': 'the location of the error', 'default': None}) + + @docval( + {"name": "name", "type": str, "doc": "the name of the component that is erroneous"}, + {"name": "data_type", "type": str, "doc": "the data type which has the incorrect quantity"}, + {"name": "expected", "type": (str, int), "doc": "the expected quantity"}, + {"name": "received", "type": (str, int), "doc": "the received quantity"}, + {"name": "location", "type": str, "doc": "the location of the error", "default": None}, + ) def __init__(self, **kwargs): - name, data_type, expected, received = getargs('name', 'data_type', 'expected', 'received', kwargs) + name, data_type, expected, received = getargs("name", "data_type", "expected", "received", kwargs) reason = "expected a quantity of %s for data type %s, received %s" % (str(expected), data_type, str(received)) - loc = getargs('location', kwargs) + loc = getargs("location", kwargs) super().__init__(name, reason, location=loc) class ExpectedArrayError(Error): - - @docval({'name': 'name', 'type': str, 'doc': 'the name of the component that is erroneous'}, - {'name': 'expected', 'type': (tuple, list), 'doc': 'the expected shape'}, - {'name': 'received', 'type': str, 'doc': 'the received data'}, - {'name': 'location', 'type': str, 'doc': 'the location of the error', 'default': None}) + @docval( + {"name": "name", "type": str, "doc": "the name of the component that is erroneous"}, + {"name": "expected", "type": (tuple, list), "doc": "the expected shape"}, + {"name": "received", "type": str, "doc": "the received data"}, + {"name": "location", "type": str, "doc": "the location of the error", "default": None}, + ) def __init__(self, **kwargs): - name = getargs('name', kwargs) - expected = getargs('expected', kwargs) - received = getargs('received', kwargs) + name = getargs("name", kwargs) + expected = getargs("expected", kwargs) + received = getargs("received", kwargs) reason = "incorrect shape - expected an array of shape '%s', got non-array data '%s'" % (expected, received) - loc = getargs('location', kwargs) + loc = getargs("location", kwargs) super().__init__(name, reason, location=loc) class ShapeError(Error): - - @docval({'name': 'name', 'type': str, 'doc': 'the name of the component that is erroneous'}, - {'name': 'expected', 'type': (tuple, list), 'doc': 'the expected shape'}, - {'name': 'received', 'type': (tuple, list), 'doc': 'the received shape'}, - {'name': 'location', 'type': str, 'doc': 'the location of the error', 'default': None}) + @docval( + {"name": "name", "type": str, "doc": "the name of the component that is erroneous"}, + {"name": "expected", "type": (tuple, list), "doc": "the expected shape"}, + {"name": "received", "type": (tuple, list), "doc": "the received shape"}, + {"name": "location", "type": str, "doc": "the location of the error", "default": None}, + ) def __init__(self, **kwargs): - name = getargs('name', kwargs) - expected = getargs('expected', kwargs) - received = getargs('received', kwargs) + name = getargs("name", kwargs) + 
expected = getargs("expected", kwargs) + received = getargs("received", kwargs) reason = "incorrect shape - expected '%s', got '%s'" % (expected, received) - loc = getargs('location', kwargs) + loc = getargs("location", kwargs) super().__init__(name, reason, location=loc) @@ -182,12 +193,14 @@ class IllegalLinkError(Error): (i.e. a dataset or a group) must be used """ - @docval({'name': 'name', 'type': str, 'doc': 'the name of the component that is erroneous'}, - {'name': 'location', 'type': str, 'doc': 'the location of the error', 'default': None}) + @docval( + {"name": "name", "type": str, "doc": "the name of the component that is erroneous"}, + {"name": "location", "type": str, "doc": "the location of the error", "default": None}, + ) def __init__(self, **kwargs): - name = getargs('name', kwargs) + name = getargs("name", kwargs) reason = "illegal use of link (linked object will not be validated)" - loc = getargs('location', kwargs) + loc = getargs("location", kwargs) super().__init__(name, reason, location=loc) @@ -196,14 +209,16 @@ class IncorrectDataType(Error): A validation error for indicating that the incorrect data_type (not dtype) was used. """ - @docval({'name': 'name', 'type': str, 'doc': 'the name of the component that is erroneous'}, - {'name': 'expected', 'type': str, 'doc': 'the expected data_type'}, - {'name': 'received', 'type': str, 'doc': 'the received data_type'}, - {'name': 'location', 'type': str, 'doc': 'the location of the error', 'default': None}) + @docval( + {"name": "name", "type": str, "doc": "the name of the component that is erroneous"}, + {"name": "expected", "type": str, "doc": "the expected data_type"}, + {"name": "received", "type": str, "doc": "the received data_type"}, + {"name": "location", "type": str, "doc": "the location of the error", "default": None}, + ) def __init__(self, **kwargs): - name = getargs('name', kwargs) - expected = getargs('expected', kwargs) - received = getargs('received', kwargs) + name = getargs("name", kwargs) + expected = getargs("expected", kwargs) + received = getargs("received", kwargs) reason = "incorrect data_type - expected '%s', got '%s'" % (expected, received) - loc = getargs('location', kwargs) + loc = getargs("location", kwargs) super().__init__(name, reason, location=loc) diff --git a/src/hdmf/validate/validator.py b/src/hdmf/validate/validator.py index 4788d32fa..b1318ee30 100644 --- a/src/hdmf/validate/validator.py +++ b/src/hdmf/validate/validator.py @@ -1,33 +1,54 @@ import re from abc import ABCMeta, abstractmethod +from collections import OrderedDict, defaultdict from copy import copy from itertools import chain -from collections import defaultdict, OrderedDict import numpy as np -from .errors import Error, DtypeError, MissingError, MissingDataType, ShapeError, IllegalLinkError, IncorrectDataType -from .errors import ExpectedArrayError, IncorrectQuantityError -from ..build import GroupBuilder, DatasetBuilder, LinkBuilder, ReferenceBuilder, RegionBuilder +from ..build import ( + DatasetBuilder, + GroupBuilder, + LinkBuilder, + ReferenceBuilder, + RegionBuilder, +) from ..build.builders import BaseBuilder -from ..spec import Spec, AttributeSpec, GroupSpec, DatasetSpec, RefSpec, LinkSpec -from ..spec import SpecNamespace -from ..spec.spec import BaseStorageSpec, DtypeHelper -from ..utils import docval, getargs, pystr, get_data_shape from ..query import ReferenceResolver - +from ..spec import ( + AttributeSpec, + DatasetSpec, + GroupSpec, + LinkSpec, + RefSpec, + Spec, + SpecNamespace, +) +from ..spec.spec import 
BaseStorageSpec, DtypeHelper +from ..utils import docval, get_data_shape, getargs, pystr +from .errors import ( + DtypeError, + Error, + ExpectedArrayError, + IllegalLinkError, + IncorrectDataType, + IncorrectQuantityError, + MissingDataType, + MissingError, + ShapeError, +) __synonyms = DtypeHelper.primary_dtype_synonyms __additional = { - 'float': ['double'], - 'int8': ['short', 'int', 'long'], - 'short': ['int', 'long'], - 'int': ['long'], - 'uint8': ['uint16', 'uint32', 'uint64'], - 'uint16': ['uint32', 'uint64'], - 'uint32': ['uint64'], - 'utf': ['ascii'] + "float": ["double"], + "int8": ["short", "int", "long"], + "short": ["int", "long"], + "int": ["long"], + "uint8": ["uint16", "uint32", "uint64"], + "uint16": ["uint32", "uint64"], + "uint32": ["uint64"], + "utf": ["ascii"], } # if the spec dtype is a key in __allowable, then all types in __allowable[key] are valid @@ -39,17 +60,17 @@ allow.extend(__synonyms[addl]) for syn in dt_syn: __allowable[syn] = allow -__allowable['numeric'] = set(chain.from_iterable(__allowable[k] for k in __allowable if 'int' in k or 'float' in k)) +__allowable["numeric"] = set(chain.from_iterable(__allowable[k] for k in __allowable if "int" in k or "float" in k)) def check_type(expected, received): - ''' + """ *expected* should come from the spec *received* should come from the data - ''' + """ if isinstance(expected, list): if len(expected) > len(received): - raise ValueError('compound type shorter than expected') + raise ValueError("compound type shorter than expected") for i, exp in enumerate(DtypeHelper.simplify_cpd_type(expected)): rec = received[i] if rec not in __allowable[exp]: @@ -57,16 +78,16 @@ def check_type(expected, received): return True else: if isinstance(received, np.dtype): - if received.char == 'O': - if 'vlen' in received.metadata: - received = received.metadata['vlen'] + if received.char == "O": + if "vlen" in received.metadata: + received = received.metadata["vlen"] else: raise ValueError("Unrecognized type: '%s'" % received) - received = 'utf' if received is str else 'ascii' - elif received.char == 'U': - received = 'utf' - elif received.char == 'S': - received = 'ascii' + received = "utf" if received is str else "ascii" + elif received.char == "U": + received = "utf" + elif received.char == "S": + received = "ascii" else: received = received.name elif isinstance(received, type): @@ -79,8 +100,10 @@ def check_type(expected, received): def get_iso8601_regex(): - isodate_re = (r'^(-?(?:[1-9][0-9]*)?[0-9]{4})-(1[0-2]|0[1-9])-(3[01]|0[1-9]|[12][0-9])T(2[0-3]|[01][0-9]):' - r'([0-5][0-9]):([0-5][0-9])(\.[0-9]+)?(Z|[+-](?:2[0-3]|[01][0-9]):[0-5][0-9])?$') + isodate_re = ( + r"^(-?(?:[1-9][0-9]*)?[0-9]{4})-(1[0-2]|0[1-9])-(3[01]|0[1-9]|[12][0-9])T(2[0-3]|[01][0-9]):" + r"([0-5][0-9]):([0-5][0-9])(\.[0-9]+)?(Z|[+-](?:2[0-3]|[01][0-9]):[0-5][0-9])?$" + ) return re.compile(isodate_re) @@ -90,7 +113,7 @@ def get_iso8601_regex(): def _check_isodatetime(s, default=None): try: if _iso_re.match(pystr(s)) is not None: - return 'isodatetime' + return "isodatetime" except Exception: pass return default @@ -102,13 +125,13 @@ class EmptyArrayError(Exception): def get_type(data): if isinstance(data, str): - return _check_isodatetime(data, 'utf') + return _check_isodatetime(data, "utf") elif isinstance(data, bytes): - return _check_isodatetime(data, 'ascii') + return _check_isodatetime(data, "ascii") elif isinstance(data, RegionBuilder): - return 'region' + return "region" elif isinstance(data, ReferenceBuilder): - return 'object' + return "object" elif 
isinstance(data, ReferenceResolver): return data.dtype elif isinstance(data, np.ndarray): @@ -116,14 +139,14 @@ def get_type(data): raise EmptyArrayError() return get_type(data[0]) elif isinstance(data, np.bool_): - return 'bool' - if not hasattr(data, '__len__'): + return "bool" + if not hasattr(data, "__len__"): return type(data).__name__ else: - if hasattr(data, 'dtype'): + if hasattr(data, "dtype"): if isinstance(data.dtype, list): return [get_type(data[0][i]) for i in range(len(data.dtype))] - if data.dtype.metadata is not None and data.dtype.metadata.get('vlen') is not None: + if data.dtype.metadata is not None and data.dtype.metadata.get("vlen") is not None: return get_type(data[0]) return data.dtype if len(data) == 0: @@ -159,9 +182,15 @@ def check_shape(expected, received): class ValidatorMap: """A class for keeping track of Validator objects for all data types in a namespace""" - @docval({'name': 'namespace', 'type': SpecNamespace, 'doc': 'the namespace to builder map for'}) + @docval( + { + "name": "namespace", + "type": SpecNamespace, + "doc": "the namespace to builder map for", + } + ) def __init__(self, **kwargs): - ns = getargs('namespace', kwargs) + ns = getargs("namespace", kwargs) self.__ns = ns tree = defaultdict(list) types = ns.get_registered_types() @@ -202,11 +231,14 @@ def __rec(self, tree, node): def namespace(self): return self.__ns - @docval({'name': 'spec', 'type': (Spec, str), 'doc': 'the specification to use to validate'}, - returns='all valid sub data types for the given spec', rtype=tuple) + @docval( + {"name": "spec", "type": (Spec, str), "doc": "the specification to use to validate"}, + returns="all valid sub data types for the given spec", + rtype=tuple, + ) def valid_types(self, **kwargs): - '''Get all valid types for a given data type''' - spec = getargs('spec', kwargs) + """Get all valid types for a given data type""" + spec = getargs("spec", kwargs) if isinstance(spec, Spec): spec = spec.data_type_def try: @@ -214,12 +246,13 @@ def valid_types(self, **kwargs): except KeyError: raise ValueError("no children for '%s'" % spec) - @docval({'name': 'data_type', 'type': (BaseStorageSpec, str), - 'doc': 'the data type to get the validator for'}, - returns='the validator ``data_type``') + @docval( + {"name": "data_type", "type": (BaseStorageSpec, str), "doc": "the data type to get the validator for"}, + returns="the validator ``data_type``", + ) def get_validator(self, **kwargs): """Return the validator for a given data type""" - dt = getargs('data_type', kwargs) + dt = getargs("data_type", kwargs) if isinstance(dt, BaseStorageSpec): dt_tmp = dt.data_type_def if dt_tmp is None: @@ -231,15 +264,18 @@ def get_validator(self, **kwargs): msg = "data type '%s' not found in namespace %s" % (dt, self.__ns.name) raise ValueError(msg) - @docval({'name': 'builder', 'type': BaseBuilder, 'doc': 'the builder to validate'}, - returns="a list of errors found", rtype=list) + @docval( + {"name": "builder", "type": BaseBuilder, "doc": "the builder to validate"}, + returns="a list of errors found", + rtype=list, + ) def validate(self, **kwargs): """Validate a builder against a Spec ``builder`` must have the attribute used to specifying data type by the namespace used to construct this ValidatorMap. 
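# Editorial aside, not part of the diff (the hunk continues below): sketch of
# driving the ValidatorMap API shown above. The namespace lookup assumes the
# hdmf-common namespace registered by hdmf.common.get_manager(); building a
# GroupBuilder to validate is elided.
from hdmf.common import get_manager
from hdmf.validate import ValidatorMap

namespace = get_manager().namespace_catalog.get_namespace("hdmf-common")
vmap = ValidatorMap(namespace)

dt_validator = vmap.get_validator("DynamicTable")
# errors = vmap.validate(builder)  # builder: a GroupBuilder carrying a data_type attribute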
""" - builder = getargs('builder', kwargs) + builder = getargs("builder", kwargs) dt = builder.attributes.get(self.__type_key) if dt is None: msg = "builder must have data type defined with attribute '%s'" % self.__type_key @@ -249,13 +285,15 @@ def validate(self, **kwargs): class Validator(metaclass=ABCMeta): - '''A base class for classes that will be used to validate against Spec subclasses''' + """A base class for classes that will be used to validate against Spec subclasses""" - @docval({'name': 'spec', 'type': Spec, 'doc': 'the specification to use to validate'}, - {'name': 'validator_map', 'type': ValidatorMap, 'doc': 'the ValidatorMap to use during validation'}) + @docval( + {"name": "spec", "type": Spec, "doc": "the specification to use to validate"}, + {"name": "validator_map", "type": ValidatorMap, "doc": "the ValidatorMap to use during validation"}, + ) def __init__(self, **kwargs): - self.__spec = getargs('spec', kwargs) - self.__vmap = getargs('validator_map', kwargs) + self.__spec = getargs("spec", kwargs) + self.__vmap = getargs("validator_map", kwargs) @property def spec(self): @@ -266,8 +304,11 @@ def vmap(self): return self.__vmap @abstractmethod - @docval({'name': 'value', 'type': None, 'doc': 'either in the form of a value or a Builder'}, - returns='a list of Errors', rtype=list) + @docval( + {"name": "value", "type": None, "doc": "either in the form of a value or a Builder"}, + returns="a list of Errors", + rtype=list, + ) def validate(self, **kwargs): pass @@ -279,24 +320,29 @@ def get_spec_loc(cls, spec): def get_builder_loc(cls, builder): stack = list() tmp = builder - while tmp is not None and tmp.name != 'root': + while tmp is not None and tmp.name != "root": stack.append(tmp.name) tmp = tmp.parent return "/".join(reversed(stack)) class AttributeValidator(Validator): - '''A class for validating values against AttributeSpecs''' + """A class for validating values against AttributeSpecs""" - @docval({'name': 'spec', 'type': AttributeSpec, 'doc': 'the specification to use to validate'}, - {'name': 'validator_map', 'type': ValidatorMap, 'doc': 'the ValidatorMap to use during validation'}) + @docval( + {"name": "spec", "type": AttributeSpec, "doc": "the specification to use to validate"}, + {"name": "validator_map", "type": ValidatorMap, "doc": "the ValidatorMap to use during validation"}, + ) def __init__(self, **kwargs): super().__init__(**kwargs) - @docval({'name': 'value', 'type': None, 'doc': 'the value to validate'}, - returns='a list of Errors', rtype=list) + @docval( + {"name": "value", "type": None, "doc": "the value to validate"}, + returns="a list of Errors", + rtype=list, + ) def validate(self, **kwargs): - value = getargs('value', kwargs) + value = getargs("value", kwargs) ret = list() spec = self.spec if spec.required and value is None: @@ -306,7 +352,7 @@ def validate(self, **kwargs): ret.append(Error(self.get_spec_loc(spec))) elif isinstance(spec.dtype, RefSpec): if not isinstance(value, BaseBuilder): - expected = '%s reference' % spec.dtype.reftype + expected = "%s reference" % spec.dtype.reftype try: value_type = get_type(value) ret.append(DtypeError(self.get_spec_loc(spec), expected, value_type)) @@ -318,7 +364,13 @@ def validate(self, **kwargs): data_type = value.attributes.get(target_spec.type_key()) hierarchy = self.vmap.namespace.catalog.get_hierarchy(data_type) if spec.dtype.target_type not in hierarchy: - ret.append(IncorrectDataType(self.get_spec_loc(spec), spec.dtype.target_type, data_type)) + ret.append( + IncorrectDataType( + 
self.get_spec_loc(spec), + spec.dtype.target_type, + data_type, + ) + ) else: try: dtype = get_type(value) @@ -330,35 +382,50 @@ def validate(self, **kwargs): shape = get_data_shape(value) if not check_shape(spec.shape, shape): if shape is None: - ret.append(ExpectedArrayError(self.get_spec_loc(self.spec), self.spec.shape, str(value))) + ret.append( + ExpectedArrayError( + self.get_spec_loc(self.spec), + self.spec.shape, + str(value), + ) + ) else: ret.append(ShapeError(self.get_spec_loc(spec), spec.shape, shape)) return ret class BaseStorageValidator(Validator): - '''A base class for validating against Spec objects that have attributes i.e. BaseStorageSpec''' + """A base class for validating against Spec objects that have attributes i.e. BaseStorageSpec""" - @docval({'name': 'spec', 'type': BaseStorageSpec, 'doc': 'the specification to use to validate'}, - {'name': 'validator_map', 'type': ValidatorMap, 'doc': 'the ValidatorMap to use during validation'}) + @docval( + {"name": "spec", "type": BaseStorageSpec, "doc": "the specification to use to validate"}, + {"name": "validator_map", "type": ValidatorMap, "doc": "the ValidatorMap to use during validation"}, + ) def __init__(self, **kwargs): super().__init__(**kwargs) self.__attribute_validators = dict() for attr in self.spec.attributes: self.__attribute_validators[attr.name] = AttributeValidator(attr, self.vmap) - @docval({"name": "builder", "type": BaseBuilder, "doc": "the builder to validate"}, - returns='a list of Errors', rtype=list) + @docval( + {"name": "builder", "type": BaseBuilder, "doc": "the builder to validate"}, + returns="a list of Errors", + rtype=list, + ) def validate(self, **kwargs): - builder = getargs('builder', kwargs) + builder = getargs("builder", kwargs) attributes = builder.attributes ret = list() for attr, validator in self.__attribute_validators.items(): attr_val = attributes.get(attr) if attr_val is None: if validator.spec.required: - ret.append(MissingError(self.get_spec_loc(validator.spec), - location=self.get_builder_loc(builder))) + ret.append( + MissingError( + self.get_spec_loc(validator.spec), + location=self.get_builder_loc(builder), + ) + ) else: errors = validator.validate(attr_val) for err in errors: @@ -368,36 +435,59 @@ def validate(self, **kwargs): class DatasetValidator(BaseStorageValidator): - '''A class for validating DatasetBuilders against DatasetSpecs''' + """A class for validating DatasetBuilders against DatasetSpecs""" - @docval({'name': 'spec', 'type': DatasetSpec, 'doc': 'the specification to use to validate'}, - {'name': 'validator_map', 'type': ValidatorMap, 'doc': 'the ValidatorMap to use during validation'}) + @docval( + {"name": "spec", "type": DatasetSpec, "doc": "the specification to use to validate"}, + {"name": "validator_map", "type": ValidatorMap, "doc": "the ValidatorMap to use during validation"}, + ) def __init__(self, **kwargs): super().__init__(**kwargs) - @docval({"name": "builder", "type": DatasetBuilder, "doc": "the builder to validate"}, - returns='a list of Errors', rtype=list) + @docval( + {"name": "builder", "type": DatasetBuilder, "doc": "the builder to validate"}, + returns="a list of Errors", + rtype=list, + ) def validate(self, **kwargs): - builder = getargs('builder', kwargs) + builder = getargs("builder", kwargs) ret = super().validate(builder) data = builder.data if self.spec.dtype is not None: try: dtype = get_type(data) if not check_type(self.spec.dtype, dtype): - ret.append(DtypeError(self.get_spec_loc(self.spec), self.spec.dtype, dtype, - 
location=self.get_builder_loc(builder))) + ret.append( + DtypeError( + self.get_spec_loc(self.spec), + self.spec.dtype, + dtype, + location=self.get_builder_loc(builder), + ) + ) except EmptyArrayError: # do not validate dtype of empty array. HDMF does not yet set dtype when writing a list/tuple pass shape = get_data_shape(data) if not check_shape(self.spec.shape, shape): if shape is None: - ret.append(ExpectedArrayError(self.get_spec_loc(self.spec), self.spec.shape, str(data), - location=self.get_builder_loc(builder))) + ret.append( + ExpectedArrayError( + self.get_spec_loc(self.spec), + self.spec.shape, + str(data), + location=self.get_builder_loc(builder), + ) + ) else: - ret.append(ShapeError(self.get_spec_loc(self.spec), self.spec.shape, shape, - location=self.get_builder_loc(builder))) + ret.append( + ShapeError( + self.get_spec_loc(self.spec), + self.spec.shape, + shape, + location=self.get_builder_loc(builder), + ) + ) return ret @@ -408,17 +498,22 @@ def _resolve_data_type(spec): class GroupValidator(BaseStorageValidator): - '''A class for validating GroupBuilders against GroupSpecs''' + """A class for validating GroupBuilders against GroupSpecs""" - @docval({'name': 'spec', 'type': GroupSpec, 'doc': 'the specification to use to validate'}, - {'name': 'validator_map', 'type': ValidatorMap, 'doc': 'the ValidatorMap to use during validation'}) + @docval( + {"name": "spec", "type": GroupSpec, "doc": "the specification to use to validate"}, + {"name": "validator_map", "type": ValidatorMap, "doc": "the ValidatorMap to use during validation"}, + ) def __init__(self, **kwargs): super().__init__(**kwargs) - @docval({"name": "builder", "type": GroupBuilder, "doc": "the builder to validate"}, # noqa: C901 - returns='a list of Errors', rtype=list) - def validate(self, **kwargs): # noqa: C901 - builder = getargs('builder', kwargs) + @docval( + {"name": "builder", "type": GroupBuilder, "doc": "the builder to validate"}, + returns="a list of Errors", + rtype=list, + ) + def validate(self, **kwargs): + builder = getargs("builder", kwargs) errors = super().validate(builder) errors.extend(self.__validate_children(builder)) return self._remove_duplicates(errors) @@ -438,9 +533,11 @@ def __validate_children(self, parent_builder): spec_children = chain(self.spec.datasets, self.spec.groups, self.spec.links) matcher = SpecMatcher(self.vmap, spec_children) - builder_children = chain(parent_builder.datasets.values(), - parent_builder.groups.values(), - parent_builder.links.values()) + builder_children = chain( + parent_builder.datasets.values(), + parent_builder.groups.values(), + parent_builder.links.values(), + ) matcher.assign_to_specs(builder_children) for child_spec, matched_builders in matcher.spec_matches: @@ -465,8 +562,12 @@ def __construct_missing_child_error(self, child_spec, parent_builder): builder_loc = self.get_builder_loc(parent_builder) if data_type is not None: name_of_erroneous = self.get_spec_loc(self.spec) - return MissingDataType(name_of_erroneous, data_type, - location=builder_loc, missing_dt_name=child_spec.name) + return MissingDataType( + name_of_erroneous, + data_type, + location=builder_loc, + missing_dt_name=child_spec.name, + ) else: name_of_erroneous = self.get_spec_loc(child_spec) return MissingError(name_of_erroneous, location=builder_loc) @@ -484,8 +585,13 @@ def __construct_incorrect_quantity_error(self, child_spec, parent_builder, n_bui name_of_erroneous = self.get_spec_loc(self.spec) data_type = _resolve_data_type(child_spec) builder_loc = 
self.get_builder_loc(parent_builder) - return IncorrectQuantityError(name_of_erroneous, data_type, expected=child_spec.quantity, - received=n_builders, location=builder_loc) + return IncorrectQuantityError( + name_of_erroneous, + data_type, + expected=child_spec.quantity, + received=n_builders, + location=builder_loc, + ) def __validate_child_builder(self, child_spec, child_builder, parent_builder): """Validate a child builder against a child spec considering links""" @@ -634,6 +740,7 @@ def _filter_by_name(self, candidates, builder): """Returns the candidate specs that either have the same name as the builder or do not specify a name. """ + def name_is_consistent(spec_matches): spec = spec_matches.spec return spec.name is None or spec.name == builder.name @@ -644,6 +751,7 @@ def _filter_by_type(self, candidates, builder): """Returns the candidate specs which have a data type consistent with the builder's data type. """ + def compatible_type(spec_matches): spec = spec_matches.spec if isinstance(spec, LinkSpec): @@ -665,6 +773,7 @@ def _filter_by_unsatisfied(self, candidates): """Returns the candidate specs which are not yet matched against a number of builders which fulfils the quantity for the spec. """ + def is_unsatisfied(spec_matches): spec = spec_matches.spec n_match = len(spec_matches.builders) diff --git a/tests/unit/back_compat_tests/test_1_1_0.py b/tests/unit/back_compat_tests/test_1_1_0.py index b21cc3ae7..d63f3840d 100644 --- a/tests/unit/back_compat_tests/test_1_1_0.py +++ b/tests/unit/back_compat_tests/test_1_1_0.py @@ -2,16 +2,16 @@ from shutil import copyfile from hdmf.backends.hdf5.h5tools import HDF5IO -from tests.unit.helpers.utils import Foo, FooBucket, get_foo_buildmanager from hdmf.testing import TestCase +from ..helpers.utils import Foo, FooBucket, get_foo_buildmanager -class Test1_1_0(TestCase): +class Test1_1_0(TestCase): def setUp(self): # created using manager in test_io_hdf5_h5tools - self.orig_1_0_5 = 'tests/unit/back_compat_tests/1.0.5.h5' - self.path_1_0_5 = 'test_1.0.5.h5' + self.orig_1_0_5 = "tests/unit/back_compat_tests/1.0.5.h5" + self.path_1_0_5 = "test_1.0.5.h5" copyfile(self.orig_1_0_5, self.path_1_0_5) # note: this may break if the current manager is different from the old manager @@ -23,23 +23,32 @@ def tearDown(self): os.remove(self.path_1_0_5) def test_read_1_0_5(self): - '''Test whether we can read files made by hdmf version 1.0.5''' - with HDF5IO(self.path_1_0_5, manager=self.manager, mode='r') as io: + """Test whether we can read files made by hdmf version 1.0.5""" + with HDF5IO(self.path_1_0_5, manager=self.manager, mode="r") as io: read_foofile = io.read() self.assertTrue(len(read_foofile.buckets) == 1) - self.assertListEqual(read_foofile.buckets['test_bucket'].foos['foo1'].my_data[:].tolist(), [0, 1, 2, 3, 4]) - self.assertListEqual(read_foofile.buckets['test_bucket'].foos['foo2'].my_data[:].tolist(), [5, 6, 7, 8, 9]) + self.assertListEqual( + read_foofile.buckets["test_bucket"].foos["foo1"].my_data[:].tolist(), + [0, 1, 2, 3, 4], + ) + self.assertListEqual( + read_foofile.buckets["test_bucket"].foos["foo2"].my_data[:].tolist(), + [5, 6, 7, 8, 9], + ) def test_append_1_0_5(self): - '''Test whether we can append to files made by hdmf version 1.0.5''' - foo = Foo('foo3', [10, 20, 30, 40, 50], "I am foo3", 17, 3.14) - foobucket = FooBucket('foobucket2', [foo]) + """Test whether we can append to files made by hdmf version 1.0.5""" + foo = Foo("foo3", [10, 20, 30, 40, 50], "I am foo3", 17, 3.14) + foobucket = FooBucket("foobucket2", [foo]) - with 
HDF5IO(self.path_1_0_5, manager=self.manager, mode='a') as io: + with HDF5IO(self.path_1_0_5, manager=self.manager, mode="a") as io: read_foofile = io.read() read_foofile.add_bucket(foobucket) io.write(read_foofile) - with HDF5IO(self.path_1_0_5, manager=self.manager, mode='r') as io: + with HDF5IO(self.path_1_0_5, manager=self.manager, mode="r") as io: read_foofile = io.read() - self.assertListEqual(read_foofile.buckets['foobucket2'].foos['foo3'].my_data[:].tolist(), foo.my_data) + self.assertListEqual( + read_foofile.buckets["foobucket2"].foos["foo3"].my_data[:].tolist(), + foo.my_data, + ) diff --git a/tests/unit/build_tests/mapper_tests/test_build.py b/tests/unit/build_tests/mapper_tests/test_build.py index 8590f29f2..4767c5988 100644 --- a/tests/unit/build_tests/mapper_tests/test_build.py +++ b/tests/unit/build_tests/mapper_tests/test_build.py @@ -1,36 +1,46 @@ from abc import ABCMeta, abstractmethod import numpy as np + from hdmf import Container, Data -from hdmf.build import ObjectMapper, BuildManager, TypeMap, GroupBuilder, DatasetBuilder +from hdmf.build import BuildManager, DatasetBuilder, GroupBuilder, ObjectMapper, TypeMap from hdmf.build.warnings import DtypeConversionWarning -from hdmf.spec import GroupSpec, AttributeSpec, DatasetSpec, SpecCatalog, SpecNamespace, NamespaceCatalog, Spec +from hdmf.spec import ( + AttributeSpec, + DatasetSpec, + GroupSpec, + NamespaceCatalog, + Spec, + SpecCatalog, + SpecNamespace, +) from hdmf.testing import TestCase from hdmf.utils import docval, getargs -from tests.unit.helpers.utils import CORE_NAMESPACE - +from ...helpers.utils import CORE_NAMESPACE # TODO: test build of extended group/dataset that modifies an attribute dtype (commented out below), shape, value, etc. # by restriction. also check that attributes cannot be deleted or scope expanded. # TODO: test build of extended dataset that modifies shape by restriction. 
-class Bar(Container): - @docval({'name': 'name', 'type': str, 'doc': 'the name of this Bar'}, - {'name': 'attr1', 'type': str, 'doc': 'a string attribute'}, - {'name': 'attr2', 'type': 'int', 'doc': 'an int attribute', 'default': None}, - {'name': 'ext_attr', 'type': bool, 'doc': 'a boolean attribute', 'default': True}) +class Bar(Container): + @docval( + {"name": "name", "type": str, "doc": "the name of this Bar"}, + {"name": "attr1", "type": str, "doc": "a string attribute"}, + {"name": "attr2", "type": "int", "doc": "an int attribute", "default": None}, + {"name": "ext_attr", "type": bool, "doc": "a boolean attribute", "default": True}, + ) def __init__(self, **kwargs): - name, attr1, attr2, ext_attr = getargs('name', 'attr1', 'attr2', 'ext_attr', kwargs) + name, attr1, attr2, ext_attr = getargs("name", "attr1", "attr2", "ext_attr", kwargs) super().__init__(name=name) self.__attr1 = attr1 self.__attr2 = attr2 - self.__ext_attr = kwargs['ext_attr'] + self.__ext_attr = kwargs["ext_attr"] @property def data_type(self): - return 'Bar' + return "Bar" @property def attr1(self): @@ -46,11 +56,12 @@ def ext_attr(self): class BarHolder(Container): - - @docval({'name': 'name', 'type': str, 'doc': 'the name of this BarHolder'}, - {'name': 'bars', 'type': ('data', 'array_data'), 'doc': 'bars', 'default': list()}) + @docval( + {"name": "name", "type": str, "doc": "the name of this BarHolder"}, + {"name": "bars", "type": ("data", "array_data"), "doc": "bars", "default": list()}, + ) def __init__(self, **kwargs): - name, bars = getargs('name', 'bars', kwargs) + name, bars = getargs("name", "bars", kwargs) super().__init__(name=name) self.__bars = bars for b in bars: @@ -59,7 +70,7 @@ def __init__(self, **kwargs): @property def data_type(self): - return 'BarHolder' + return "BarHolder" @property def bars(self): @@ -67,59 +78,59 @@ def bars(self): class ExtBarMapper(ObjectMapper): - - @docval({"name": "spec", "type": Spec, "doc": "the spec to get the attribute value for"}, - {"name": "container", "type": Bar, "doc": "the container to get the attribute value from"}, - {"name": "manager", "type": BuildManager, "doc": "the BuildManager used for managing this build"}, - returns='the value of the attribute') + @docval( + {"name": "spec", "type": Spec, "doc": "the spec to get the attribute value for"}, + {"name": "container", "type": Bar, "doc": "the container to get the attribute value from"}, + {"name": "manager", "type": BuildManager, "doc": "the BuildManager used for managing this build"}, + returns="the value of the attribute", + ) def get_attr_value(self, **kwargs): - ''' Get the value of the attribute corresponding to this spec from the given container ''' - spec, container, manager = getargs('spec', 'container', 'manager', kwargs) + """Get the value of the attribute corresponding to this spec from the given container""" + spec, container, manager = getargs("spec", "container", "manager", kwargs) # handle custom mapping of field 'ext_attr' within container BarHolder/Bar -> spec BarHolder/Bar.ext_attr if isinstance(container.parent, BarHolder): - if spec.name == 'ext_attr': + if spec.name == "ext_attr": return container.ext_attr return super().get_attr_value(**kwargs) class BuildGroupExtAttrsMixin(TestCase, metaclass=ABCMeta): - def setUp(self): self.setUpBarSpec() self.setUpBarHolderSpec() spec_catalog = SpecCatalog() - spec_catalog.register_spec(self.bar_spec, 'test.yaml') - spec_catalog.register_spec(self.bar_holder_spec, 'test.yaml') + spec_catalog.register_spec(self.bar_spec, "test.yaml") + 
spec_catalog.register_spec(self.bar_holder_spec, "test.yaml") namespace = SpecNamespace( - doc='a test namespace', + doc="a test namespace", name=CORE_NAMESPACE, - schema=[{'source': 'test.yaml'}], - version='0.1.0', - catalog=spec_catalog + schema=[{"source": "test.yaml"}], + version="0.1.0", + catalog=spec_catalog, ) namespace_catalog = NamespaceCatalog() namespace_catalog.add_namespace(CORE_NAMESPACE, namespace) type_map = TypeMap(namespace_catalog) - type_map.register_container_type(CORE_NAMESPACE, 'Bar', Bar) - type_map.register_container_type(CORE_NAMESPACE, 'BarHolder', BarHolder) + type_map.register_container_type(CORE_NAMESPACE, "Bar", Bar) + type_map.register_container_type(CORE_NAMESPACE, "BarHolder", BarHolder) type_map.register_map(Bar, ExtBarMapper) type_map.register_map(BarHolder, ObjectMapper) self.manager = BuildManager(type_map) def setUpBarSpec(self): attr1_attr = AttributeSpec( - name='attr1', - dtype='text', - doc='an example string attribute', + name="attr1", + dtype="text", + doc="an example string attribute", ) attr2_attr = AttributeSpec( - name='attr2', - dtype='int', - doc='an example int attribute', + name="attr2", + dtype="int", + doc="an example int attribute", ) self.bar_spec = GroupSpec( - doc='A test group specification with a data type', - data_type_def='Bar', + doc="A test group specification with a data type", + data_type_def="Bar", attributes=[attr1_attr, attr2_attr], ) @@ -138,19 +149,19 @@ class TestBuildGroupAddedAttr(BuildGroupExtAttrsMixin, TestCase): def setUpBarHolderSpec(self): ext_attr = AttributeSpec( - name='ext_attr', - dtype='bool', - doc='A boolean attribute', + name="ext_attr", + dtype="bool", + doc="A boolean attribute", ) bar_ext_no_name_spec = GroupSpec( - doc='A Bar extended with attribute ext_attr', - data_type_inc='Bar', - quantity='*', + doc="A Bar extended with attribute ext_attr", + data_type_inc="Bar", + quantity="*", attributes=[ext_attr], ) self.bar_holder_spec = GroupSpec( - doc='A container of multiple extended Bar objects', - data_type_def='BarHolder', + doc="A container of multiple extended Bar objects", + data_type_def="BarHolder", groups=[bar_ext_no_name_spec], ) @@ -159,39 +170,39 @@ def test_build_added_attr(self): Test build of BarHolder which can contain multiple extended Bar objects, which have a new attribute. 
""" ext_bar_inst = Bar( - name='my_bar', - attr1='a string', + name="my_bar", + attr1="a string", attr2=10, ext_attr=False, ) bar_holder_inst = BarHolder( - name='my_bar_holder', + name="my_bar_holder", bars=[ext_bar_inst], ) expected_inner = GroupBuilder( - name='my_bar', + name="my_bar", attributes={ - 'attr1': 'a string', - 'attr2': 10, - 'data_type': 'Bar', - 'ext_attr': False, - 'namespace': CORE_NAMESPACE, - 'object_id': ext_bar_inst.object_id, + "attr1": "a string", + "attr2": 10, + "data_type": "Bar", + "ext_attr": False, + "namespace": CORE_NAMESPACE, + "object_id": ext_bar_inst.object_id, }, ) expected = GroupBuilder( - name='my_bar_holder', - groups={'my_bar': expected_inner}, + name="my_bar_holder", + groups={"my_bar": expected_inner}, attributes={ - 'data_type': 'BarHolder', - 'namespace': CORE_NAMESPACE, - 'object_id': bar_holder_inst.object_id, + "data_type": "BarHolder", + "namespace": CORE_NAMESPACE, + "object_id": bar_holder_inst.object_id, }, ) # the object mapper automatically maps the spec of extended Bars to the 'BarMapper.bars' field - builder = self.manager.build(bar_holder_inst, source='test.h5') + builder = self.manager.build(bar_holder_inst, source="test.h5") self.assertDictEqual(builder, expected) @@ -205,19 +216,19 @@ class TestBuildGroupRefinedAttr(BuildGroupExtAttrsMixin, TestCase): def setUpBarHolderSpec(self): int_attr2 = AttributeSpec( - name='attr2', - dtype='int64', - doc='Refine Bar spec from int to int64', + name="attr2", + dtype="int64", + doc="Refine Bar spec from int to int64", ) bar_ext_no_name_spec = GroupSpec( - doc='A Bar extended with modified attribute attr2', - data_type_inc='Bar', - quantity='*', + doc="A Bar extended with modified attribute attr2", + data_type_inc="Bar", + quantity="*", attributes=[int_attr2], ) self.bar_holder_spec = GroupSpec( - doc='A container of multiple extended Bar objects', - data_type_def='BarHolder', + doc="A container of multiple extended Bar objects", + data_type_def="BarHolder", groups=[bar_ext_no_name_spec], ) @@ -226,37 +237,37 @@ def test_build_refined_attr(self): Test build of BarHolder which can contain multiple extended Bar objects, which have a modified attr2. 
""" ext_bar_inst = Bar( - name='my_bar', - attr1='a string', + name="my_bar", + attr1="a string", attr2=np.int64(10), ) bar_holder_inst = BarHolder( - name='my_bar_holder', + name="my_bar_holder", bars=[ext_bar_inst], ) expected_inner = GroupBuilder( - name='my_bar', + name="my_bar", attributes={ - 'attr1': 'a string', - 'attr2': np.int64(10), - 'data_type': 'Bar', - 'namespace': CORE_NAMESPACE, - 'object_id': ext_bar_inst.object_id, - } + "attr1": "a string", + "attr2": np.int64(10), + "data_type": "Bar", + "namespace": CORE_NAMESPACE, + "object_id": ext_bar_inst.object_id, + }, ) expected = GroupBuilder( - name='my_bar_holder', - groups={'my_bar': expected_inner}, + name="my_bar_holder", + groups={"my_bar": expected_inner}, attributes={ - 'data_type': 'BarHolder', - 'namespace': CORE_NAMESPACE, - 'object_id': bar_holder_inst.object_id, + "data_type": "BarHolder", + "namespace": CORE_NAMESPACE, + "object_id": bar_holder_inst.object_id, }, ) # the object mapper automatically maps the spec of extended Bars to the 'BarMapper.bars' field - builder = self.manager.build(bar_holder_inst, source='test.h5') + builder = self.manager.build(bar_holder_inst, source="test.h5") self.assertDictEqual(builder, expected) # def test_build_refined_attr_wrong_type(self): @@ -301,22 +312,23 @@ def test_build_refined_attr(self): class BarData(Data): - - @docval({'name': 'name', 'type': str, 'doc': 'the name of this BarData'}, - {'name': 'data', 'type': ('data', 'array_data'), 'doc': 'the data'}, - {'name': 'attr1', 'type': str, 'doc': 'a string attribute'}, - {'name': 'attr2', 'type': 'int', 'doc': 'an int attribute', 'default': None}, - {'name': 'ext_attr', 'type': bool, 'doc': 'a boolean attribute', 'default': True}) + @docval( + {"name": "name", "type": str, "doc": "the name of this BarData"}, + {"name": "data", "type": ("data", "array_data"), "doc": "the data"}, + {"name": "attr1", "type": str, "doc": "a string attribute"}, + {"name": "attr2", "type": "int", "doc": "an int attribute", "default": None}, + {"name": "ext_attr", "type": bool, "doc": "a boolean attribute", "default": True}, + ) def __init__(self, **kwargs): - name, data, attr1, attr2, ext_attr = getargs('name', 'data', 'attr1', 'attr2', 'ext_attr', kwargs) + name, data, attr1, attr2, ext_attr = getargs("name", "data", "attr1", "attr2", "ext_attr", kwargs) super().__init__(name=name, data=data) self.__attr1 = attr1 self.__attr2 = attr2 - self.__ext_attr = kwargs['ext_attr'] + self.__ext_attr = kwargs["ext_attr"] @property def data_type(self): - return 'BarData' + return "BarData" @property def attr1(self): @@ -332,11 +344,12 @@ def ext_attr(self): class BarDataHolder(Container): - - @docval({'name': 'name', 'type': str, 'doc': 'the name of this BarDataHolder'}, - {'name': 'bar_datas', 'type': ('data', 'array_data'), 'doc': 'bar_datas', 'default': list()}) + @docval( + {"name": "name", "type": str, "doc": "the name of this BarDataHolder"}, + {"name": "bar_datas", "type": ("data", "array_data"), "doc": "bar_datas", "default": list()}, + ) def __init__(self, **kwargs): - name, bar_datas = getargs('name', 'bar_datas', kwargs) + name, bar_datas = getargs("name", "bar_datas", kwargs) super().__init__(name=name) self.__bar_datas = bar_datas for b in bar_datas: @@ -345,7 +358,7 @@ def __init__(self, **kwargs): @property def data_type(self): - return 'BarDataHolder' + return "BarDataHolder" @property def bar_datas(self): @@ -353,66 +366,66 @@ def bar_datas(self): class ExtBarDataMapper(ObjectMapper): - - @docval({"name": "spec", "type": Spec, "doc": "the 
spec to get the attribute value for"}, - {"name": "container", "type": BarData, "doc": "the container to get the attribute value from"}, - {"name": "manager", "type": BuildManager, "doc": "the BuildManager used for managing this build"}, - returns='the value of the attribute') + @docval( + {"name": "spec", "type": Spec, "doc": "the spec to get the attribute value for"}, + {"name": "container", "type": BarData, "doc": "the container to get the attribute value from"}, + {"name": "manager", "type": BuildManager, "doc": "the BuildManager used for managing this build"}, + returns="the value of the attribute", + ) def get_attr_value(self, **kwargs): - ''' Get the value of the attribute corresponding to this spec from the given container ''' - spec, container, manager = getargs('spec', 'container', 'manager', kwargs) + """Get the value of the attribute corresponding to this spec from the given container""" + spec, container, manager = getargs("spec", "container", "manager", kwargs) # handle custom mapping of field 'ext_attr' within container # BardataHolder/BarData -> spec BarDataHolder/BarData.ext_attr if isinstance(container.parent, BarDataHolder): - if spec.name == 'ext_attr': + if spec.name == "ext_attr": return container.ext_attr return super().get_attr_value(**kwargs) class BuildDatasetExtAttrsMixin(TestCase, metaclass=ABCMeta): - def setUp(self): self.set_up_specs() spec_catalog = SpecCatalog() - spec_catalog.register_spec(self.bar_data_spec, 'test.yaml') - spec_catalog.register_spec(self.bar_data_holder_spec, 'test.yaml') + spec_catalog.register_spec(self.bar_data_spec, "test.yaml") + spec_catalog.register_spec(self.bar_data_holder_spec, "test.yaml") namespace = SpecNamespace( - doc='a test namespace', + doc="a test namespace", name=CORE_NAMESPACE, - schema=[{'source': 'test.yaml'}], - version='0.1.0', - catalog=spec_catalog + schema=[{"source": "test.yaml"}], + version="0.1.0", + catalog=spec_catalog, ) namespace_catalog = NamespaceCatalog() namespace_catalog.add_namespace(CORE_NAMESPACE, namespace) type_map = TypeMap(namespace_catalog) - type_map.register_container_type(CORE_NAMESPACE, 'BarData', BarData) - type_map.register_container_type(CORE_NAMESPACE, 'BarDataHolder', BarDataHolder) + type_map.register_container_type(CORE_NAMESPACE, "BarData", BarData) + type_map.register_container_type(CORE_NAMESPACE, "BarDataHolder", BarDataHolder) type_map.register_map(BarData, ExtBarDataMapper) type_map.register_map(BarDataHolder, ObjectMapper) self.manager = BuildManager(type_map) def set_up_specs(self): attr1_attr = AttributeSpec( - name='attr1', - dtype='text', - doc='an example string attribute', + name="attr1", + dtype="text", + doc="an example string attribute", ) attr2_attr = AttributeSpec( - name='attr2', - dtype='int', - doc='an example int attribute', + name="attr2", + dtype="int", + doc="an example int attribute", ) self.bar_data_spec = DatasetSpec( - doc='A test dataset specification with a data type', - data_type_def='BarData', - dtype='int', + doc="A test dataset specification with a data type", + data_type_def="BarData", + dtype="int", shape=[[None], [None, None]], attributes=[attr1_attr, attr2_attr], ) self.bar_data_holder_spec = GroupSpec( - doc='A container of multiple extended BarData objects', - data_type_def='BarDataHolder', + doc="A container of multiple extended BarData objects", + data_type_def="BarDataHolder", datasets=[self.get_refined_bar_data_spec()], ) @@ -432,14 +445,14 @@ class TestBuildDatasetAddedAttrs(BuildDatasetExtAttrsMixin, TestCase): def 
get_refined_bar_data_spec(self): ext_attr = AttributeSpec( - name='ext_attr', - dtype='bool', - doc='A boolean attribute', + name="ext_attr", + dtype="bool", + doc="A boolean attribute", ) refined_spec = DatasetSpec( - doc='A BarData extended with attribute ext_attr', - data_type_inc='BarData', - quantity='*', + doc="A BarData extended with attribute ext_attr", + data_type_inc="BarData", + quantity="*", attributes=[ext_attr], ) return refined_spec @@ -449,41 +462,41 @@ def test_build_added_attr(self): Test build of BarHolder which can contain multiple extended BarData objects, which have a new attribute. """ ext_bar_data_inst = BarData( - name='my_bar', + name="my_bar", data=list(range(10)), - attr1='a string', + attr1="a string", attr2=10, ext_attr=False, ) bar_data_holder_inst = BarDataHolder( - name='my_bar_holder', + name="my_bar_holder", bar_datas=[ext_bar_data_inst], ) expected_inner = DatasetBuilder( - name='my_bar', + name="my_bar", data=list(range(10)), attributes={ - 'attr1': 'a string', - 'attr2': 10, - 'data_type': 'BarData', - 'ext_attr': False, - 'namespace': CORE_NAMESPACE, - 'object_id': ext_bar_data_inst.object_id, + "attr1": "a string", + "attr2": 10, + "data_type": "BarData", + "ext_attr": False, + "namespace": CORE_NAMESPACE, + "object_id": ext_bar_data_inst.object_id, }, ) expected = GroupBuilder( - name='my_bar_holder', - datasets={'my_bar': expected_inner}, + name="my_bar_holder", + datasets={"my_bar": expected_inner}, attributes={ - 'data_type': 'BarDataHolder', - 'namespace': CORE_NAMESPACE, - 'object_id': bar_data_holder_inst.object_id, + "data_type": "BarDataHolder", + "namespace": CORE_NAMESPACE, + "object_id": bar_data_holder_inst.object_id, }, ) # the object mapper automatically maps the spec of extended Bars to the 'BarMapper.bars' field - builder = self.manager.build(bar_data_holder_inst, source='test.h5') + builder = self.manager.build(bar_data_holder_inst, source="test.h5") self.assertDictEqual(builder, expected) @@ -498,10 +511,10 @@ class TestBuildDatasetRefinedDtype(BuildDatasetExtAttrsMixin, TestCase): def get_refined_bar_data_spec(self): refined_spec = DatasetSpec( - doc='A BarData with refined int64 dtype', - data_type_inc='BarData', - dtype='int64', - quantity='*', + doc="A BarData with refined int64 dtype", + data_type_inc="BarData", + dtype="int64", + quantity="*", ) return refined_spec @@ -510,44 +523,46 @@ def test_build_refined_dtype_convert(self): Test build of BarDataHolder which contains a BarData with data that needs to be converted to the refined dtype. 
""" ext_bar_data_inst = BarData( - name='my_bar', + name="my_bar", data=np.array([1, 2], dtype=np.int32), # the refined spec says data should be int64s - attr1='a string', + attr1="a string", attr2=10, ) bar_data_holder_inst = BarDataHolder( - name='my_bar_holder', + name="my_bar_holder", bar_datas=[ext_bar_data_inst], ) expected_inner = DatasetBuilder( - name='my_bar', + name="my_bar", data=np.array([1, 2], dtype=np.int64), # the objectmapper should convert the given data to int64s attributes={ - 'attr1': 'a string', - 'attr2': 10, - 'data_type': 'BarData', - 'namespace': CORE_NAMESPACE, - 'object_id': ext_bar_data_inst.object_id, + "attr1": "a string", + "attr2": 10, + "data_type": "BarData", + "namespace": CORE_NAMESPACE, + "object_id": ext_bar_data_inst.object_id, }, ) expected = GroupBuilder( - name='my_bar_holder', - datasets={'my_bar': expected_inner}, + name="my_bar_holder", + datasets={"my_bar": expected_inner}, attributes={ - 'data_type': 'BarDataHolder', - 'namespace': CORE_NAMESPACE, - 'object_id': bar_data_holder_inst.object_id, + "data_type": "BarDataHolder", + "namespace": CORE_NAMESPACE, + "object_id": bar_data_holder_inst.object_id, }, ) # the object mapper automatically maps the spec of extended Bars to the 'BarMapper.bars' field - msg = ("Spec 'BarDataHolder/BarData': Value with data type int32 is being converted to data type int64 " - "as specified.") + msg = ( + "Spec 'BarDataHolder/BarData': Value with data type int32 is being" + " converted to data type int64 as specified." + ) with self.assertWarnsWith(DtypeConversionWarning, msg): - builder = self.manager.build(bar_data_holder_inst, source='test.h5') - np.testing.assert_array_equal(builder.datasets['my_bar'].data, expected.datasets['my_bar'].data) - self.assertEqual(builder.datasets['my_bar'].data.dtype, np.int64) + builder = self.manager.build(bar_data_holder_inst, source="test.h5") + np.testing.assert_array_equal(builder.datasets["my_bar"].data, expected.datasets["my_bar"].data) + self.assertEqual(builder.datasets["my_bar"].data.dtype, np.int64) class TestBuildDatasetNotRefinedDtype(BuildDatasetExtAttrsMixin, TestCase): @@ -561,9 +576,9 @@ class TestBuildDatasetNotRefinedDtype(BuildDatasetExtAttrsMixin, TestCase): def get_refined_bar_data_spec(self): refined_spec = DatasetSpec( - doc='A BarData', - data_type_inc='BarData', - quantity='*', + doc="A BarData", + data_type_inc="BarData", + quantity="*", ) return refined_spec @@ -572,39 +587,39 @@ def test_build_correct_dtype(self): Test build of BarDataHolder which contains a BarData. 
""" ext_bar_data_inst = BarData( - name='my_bar', + name="my_bar", data=[1, 2], - attr1='a string', + attr1="a string", attr2=10, ) bar_data_holder_inst = BarDataHolder( - name='my_bar_holder', + name="my_bar_holder", bar_datas=[ext_bar_data_inst], ) expected_inner = DatasetBuilder( - name='my_bar', + name="my_bar", data=[1, 2], attributes={ - 'attr1': 'a string', - 'attr2': 10, - 'data_type': 'BarData', - 'namespace': CORE_NAMESPACE, - 'object_id': ext_bar_data_inst.object_id, + "attr1": "a string", + "attr2": 10, + "data_type": "BarData", + "namespace": CORE_NAMESPACE, + "object_id": ext_bar_data_inst.object_id, }, ) expected = GroupBuilder( - name='my_bar_holder', - datasets={'my_bar': expected_inner}, + name="my_bar_holder", + datasets={"my_bar": expected_inner}, attributes={ - 'data_type': 'BarDataHolder', - 'namespace': CORE_NAMESPACE, - 'object_id': bar_data_holder_inst.object_id, + "data_type": "BarDataHolder", + "namespace": CORE_NAMESPACE, + "object_id": bar_data_holder_inst.object_id, }, ) # the object mapper automatically maps the spec of extended Bars to the 'BarMapper.bars' field - builder = self.manager.build(bar_data_holder_inst, source='test.h5') + builder = self.manager.build(bar_data_holder_inst, source="test.h5") self.assertDictEqual(builder, expected) def test_build_incorrect_dtype(self): @@ -612,17 +627,17 @@ def test_build_incorrect_dtype(self): Test build of BarDataHolder which contains a BarData """ ext_bar_data_inst = BarData( - name='my_bar', - data=['a', 'b'], - attr1='a string', + name="my_bar", + data=["a", "b"], + attr1="a string", attr2=10, ) bar_data_holder_inst = BarDataHolder( - name='my_bar_holder', + name="my_bar_holder", bar_datas=[ext_bar_data_inst], ) # the object mapper automatically maps the spec of extended Bars to the 'BarMapper.bars' field msg = "could not resolve dtype for BarData 'my_bar'" with self.assertRaisesWith(Exception, msg): - self.manager.build(bar_data_holder_inst, source='test.h5') + self.manager.build(bar_data_holder_inst, source="test.h5") diff --git a/tests/unit/build_tests/mapper_tests/test_build_quantity.py b/tests/unit/build_tests/mapper_tests/test_build_quantity.py index 797c8a6bf..33d4b81db 100644 --- a/tests/unit/build_tests/mapper_tests/test_build_quantity.py +++ b/tests/unit/build_tests/mapper_tests/test_build_quantity.py @@ -1,13 +1,27 @@ from hdmf import Container, Data -from hdmf.build import (BuildManager, TypeMap, GroupBuilder, DatasetBuilder, LinkBuilder, ObjectMapper, - MissingRequiredBuildWarning, IncorrectQuantityBuildWarning) -from hdmf.spec import GroupSpec, DatasetSpec, LinkSpec, SpecCatalog, SpecNamespace, NamespaceCatalog -from hdmf.spec.spec import ZERO_OR_MANY, ONE_OR_MANY, ZERO_OR_ONE, DEF_QUANTITY +from hdmf.build import ( + BuildManager, + DatasetBuilder, + GroupBuilder, + IncorrectQuantityBuildWarning, + LinkBuilder, + MissingRequiredBuildWarning, + ObjectMapper, + TypeMap, +) +from hdmf.spec import ( + DatasetSpec, + GroupSpec, + LinkSpec, + NamespaceCatalog, + SpecCatalog, + SpecNamespace, +) +from hdmf.spec.spec import DEF_QUANTITY, ONE_OR_MANY, ZERO_OR_MANY, ZERO_OR_ONE from hdmf.testing import TestCase from hdmf.utils import docval, getargs -from tests.unit.helpers.utils import CORE_NAMESPACE - +from ...helpers.utils import CORE_NAMESPACE ########################## # test all crosses: @@ -39,13 +53,14 @@ class NotSimpleQux(Data): class SimpleBucket(Container): - - @docval({'name': 'name', 'type': str, 'doc': 'the name of this SimpleBucket'}, - {'name': 'foos', 'type': list, 'doc': 'the SimpleFoo 
objects', 'default': list()}, - {'name': 'quxs', 'type': list, 'doc': 'the SimpleQux objects', 'default': list()}, - {'name': 'links', 'type': list, 'doc': 'another way to store SimpleFoo objects', 'default': list()}) + @docval( + {"name": "name", "type": str, "doc": "the name of this SimpleBucket"}, + {"name": "foos", "type": list, "doc": "the SimpleFoo objects", "default": list()}, + {"name": "quxs", "type": list, "doc": "the SimpleQux objects", "default": list()}, + {"name": "links", "type": list, "doc": "another way to store SimpleFoo objects", "default": list()}, + ) def __init__(self, **kwargs): - name, foos, quxs, links = getargs('name', 'foos', 'quxs', 'links', kwargs) + name, foos, quxs, links = getargs("name", "foos", "quxs", "links", kwargs) super().__init__(name=name) # note: collections of groups are unordered in HDF5, so make these dictionaries for keyed access self.foos = {f.name: f for f in foos} @@ -60,18 +75,29 @@ def __init__(self, **kwargs): class BasicBucket(Container): - - @docval({'name': 'name', 'type': str, 'doc': 'the name of this BasicBucket'}, - {'name': 'untyped_dataset', 'type': 'scalar_data', - 'doc': 'a scalar dataset within this BasicBucket', 'default': None}, - {'name': 'untyped_array_dataset', 'type': ('data', 'array_data'), - 'doc': 'an array dataset within this BasicBucket', 'default': None},) + @docval( + { + "name": "name", + "type": str, + "doc": "the name of this BasicBucket", + }, + { + "name": "untyped_dataset", + "type": "scalar_data", + "doc": "a scalar dataset within this BasicBucket", + "default": None, + }, + { + "name": "untyped_array_dataset", + "type": ("data", "array_data"), + "doc": "an array dataset within this BasicBucket", + "default": None, + }, + ) def __init__(self, **kwargs): - name, untyped_dataset, untyped_array_dataset = getargs('name', 'untyped_dataset', 'untyped_array_dataset', - kwargs) - super().__init__(name=name) - self.untyped_dataset = untyped_dataset - self.untyped_array_dataset = untyped_array_dataset + super().__init__(name=kwargs["name"]) + self.untyped_dataset = kwargs["untyped_dataset"] + self.untyped_array_dataset = kwargs["untyped_array_dataset"] class BuildQuantityMixin: @@ -79,24 +105,24 @@ class BuildQuantityMixin: def setUpManager(self, specs): spec_catalog = SpecCatalog() - schema_file = 'test.yaml' + schema_file = "test.yaml" for s in specs: spec_catalog.register_spec(s, schema_file) namespace = SpecNamespace( - doc='a test namespace', + doc="a test namespace", name=CORE_NAMESPACE, - schema=[{'source': schema_file}], - version='0.1.0', - catalog=spec_catalog + schema=[{"source": schema_file}], + version="0.1.0", + catalog=spec_catalog, ) namespace_catalog = NamespaceCatalog() namespace_catalog.add_namespace(CORE_NAMESPACE, namespace) type_map = TypeMap(namespace_catalog) - type_map.register_container_type(CORE_NAMESPACE, 'SimpleFoo', SimpleFoo) - type_map.register_container_type(CORE_NAMESPACE, 'NotSimpleFoo', NotSimpleFoo) - type_map.register_container_type(CORE_NAMESPACE, 'SimpleQux', SimpleQux) - type_map.register_container_type(CORE_NAMESPACE, 'NotSimpleQux', NotSimpleQux) - type_map.register_container_type(CORE_NAMESPACE, 'SimpleBucket', SimpleBucket) + type_map.register_container_type(CORE_NAMESPACE, "SimpleFoo", SimpleFoo) + type_map.register_container_type(CORE_NAMESPACE, "NotSimpleFoo", NotSimpleFoo) + type_map.register_container_type(CORE_NAMESPACE, "SimpleQux", SimpleQux) + type_map.register_container_type(CORE_NAMESPACE, "NotSimpleQux", NotSimpleQux) + 
type_map.register_container_type(CORE_NAMESPACE, "SimpleBucket", SimpleBucket) type_map.register_map(SimpleBucket, self.setUpBucketMapper()) self.manager = BuildManager(type_map) @@ -105,79 +131,82 @@ def _create_builder(self, container): if isinstance(container, Container): ret = GroupBuilder( name=container.name, - attributes={'namespace': container.namespace, - 'data_type': container.data_type, - 'object_id': container.object_id} + attributes={ + "namespace": container.namespace, + "data_type": container.data_type, + "object_id": container.object_id, + }, ) else: ret = DatasetBuilder( name=container.name, data=container.data, - attributes={'namespace': container.namespace, - 'data_type': container.data_type, - 'object_id': container.object_id} + attributes={ + "namespace": container.namespace, + "data_type": container.data_type, + "object_id": container.object_id, + }, ) return ret class TypeIncUntypedGroupMixin: - def create_specs(self, quantity): # Type SimpleBucket contains: # - an untyped group "foo_holder" which contains [quantity] groups of data_type_inc SimpleFoo # - an untyped group "qux_holder" which contains [quantity] datasets of data_type_inc SimpleQux # - an untyped group "link_holder" which contains [quantity] links of target_type SimpleFoo foo_spec = GroupSpec( - doc='A test group specification with a data type', - data_type_def='SimpleFoo', + doc="A test group specification with a data type", + data_type_def="SimpleFoo", ) not_foo_spec = GroupSpec( - doc='A test group specification with a data type', - data_type_def='NotSimpleFoo', + doc="A test group specification with a data type", + data_type_def="NotSimpleFoo", ) qux_spec = DatasetSpec( - doc='A test group specification with a data type', - data_type_def='SimpleQux', + doc="A test group specification with a data type", + data_type_def="SimpleQux", ) not_qux_spec = DatasetSpec( - doc='A test group specification with a data type', - data_type_def='NotSimpleQux', + doc="A test group specification with a data type", + data_type_def="NotSimpleQux", ) foo_inc_spec = GroupSpec( - doc='the SimpleFoos in this bucket', - data_type_inc='SimpleFoo', - quantity=quantity + doc="the SimpleFoos in this bucket", + data_type_inc="SimpleFoo", + quantity=quantity, ) foo_holder_spec = GroupSpec( - doc='An untyped subgroup for SimpleFoos', - name='foo_holder', - groups=[foo_inc_spec] + doc="An untyped subgroup for SimpleFoos", + name="foo_holder", + groups=[foo_inc_spec], ) qux_inc_spec = DatasetSpec( - doc='the SimpleQuxs in this bucket', - data_type_inc='SimpleQux', - quantity=quantity + doc="the SimpleQuxs in this bucket", + data_type_inc="SimpleQux", + quantity=quantity, ) qux_holder_spec = GroupSpec( - doc='An untyped subgroup for SimpleQuxs', - name='qux_holder', - datasets=[qux_inc_spec] + doc="An untyped subgroup for SimpleQuxs", + name="qux_holder", + datasets=[qux_inc_spec], ) foo_link_spec = LinkSpec( - doc='the links in this bucket', - target_type='SimpleFoo', - quantity=quantity + doc="the links in this bucket", + target_type="SimpleFoo", + quantity=quantity, ) link_holder_spec = GroupSpec( - doc='An untyped subgroup for links', - name='link_holder', - links=[foo_link_spec] + doc="An untyped subgroup for links", + name="link_holder", + links=[foo_link_spec], ) bucket_spec = GroupSpec( - doc='A test group specification for a data type containing data type', + doc="A test group specification for a data type containing data type", name="test_bucket", - data_type_def='SimpleBucket', - groups=[foo_holder_spec, qux_holder_spec, 
link_holder_spec] + data_type_def="SimpleBucket", + groups=[foo_holder_spec, qux_holder_spec, link_holder_spec], ) return [foo_spec, not_foo_spec, qux_spec, not_qux_spec, bucket_spec] @@ -185,164 +214,148 @@ def setUpBucketMapper(self): class BucketMapper(ObjectMapper): def __init__(self, spec): super().__init__(spec) - self.unmap(spec.get_group('foo_holder')) - self.map_spec('foos', spec.get_group('foo_holder').get_data_type('SimpleFoo')) - self.unmap(spec.get_group('qux_holder')) - self.map_spec('quxs', spec.get_group('qux_holder').get_data_type('SimpleQux')) - self.unmap(spec.get_group('link_holder')) - self.map_spec('links', spec.get_group('link_holder').links[0]) + self.unmap(spec.get_group("foo_holder")) + self.map_spec( + "foos", + spec.get_group("foo_holder").get_data_type("SimpleFoo"), + ) + self.unmap(spec.get_group("qux_holder")) + self.map_spec( + "quxs", + spec.get_group("qux_holder").get_data_type("SimpleQux"), + ) + self.unmap(spec.get_group("link_holder")) + self.map_spec("links", spec.get_group("link_holder").links[0]) return BucketMapper def get_two_bucket_test(self): - foos = [SimpleFoo('my_foo1'), SimpleFoo('my_foo2')] - quxs = [SimpleQux('my_qux1', data=[1, 2, 3]), SimpleQux('my_qux2', data=[4, 5, 6])] - bucket = SimpleBucket( - name='test_bucket', - foos=foos, - quxs=quxs, - links=foos - ) - foo1_builder = self._create_builder(bucket.foos['my_foo1']) - foo2_builder = self._create_builder(bucket.foos['my_foo2']) - qux1_builder = self._create_builder(bucket.quxs['my_qux1']) - qux2_builder = self._create_builder(bucket.quxs['my_qux2']) + foos = [SimpleFoo("my_foo1"), SimpleFoo("my_foo2")] + quxs = [ + SimpleQux("my_qux1", data=[1, 2, 3]), + SimpleQux("my_qux2", data=[4, 5, 6]), + ] + bucket = SimpleBucket(name="test_bucket", foos=foos, quxs=quxs, links=foos) + foo1_builder = self._create_builder(bucket.foos["my_foo1"]) + foo2_builder = self._create_builder(bucket.foos["my_foo2"]) + qux1_builder = self._create_builder(bucket.quxs["my_qux1"]) + qux2_builder = self._create_builder(bucket.quxs["my_qux2"]) foo_holder_builder = GroupBuilder( - name='foo_holder', - groups={'my_foo1': foo1_builder, - 'my_foo2': foo2_builder} + name="foo_holder", + groups={"my_foo1": foo1_builder, "my_foo2": foo2_builder}, ) qux_holder_builder = GroupBuilder( - name='qux_holder', - datasets={'my_qux1': qux1_builder, - 'my_qux2': qux2_builder} + name="qux_holder", + datasets={"my_qux1": qux1_builder, "my_qux2": qux2_builder}, ) foo1_link_builder = LinkBuilder(builder=foo1_builder) foo2_link_builder = LinkBuilder(builder=foo2_builder) link_holder_builder = GroupBuilder( - name='link_holder', - links={'my_foo1': foo1_link_builder, - 'my_foo2': foo2_link_builder} + name="link_holder", + links={"my_foo1": foo1_link_builder, "my_foo2": foo2_link_builder}, ) bucket_builder = GroupBuilder( - name='test_bucket', - groups={'foos': foo_holder_builder, - 'quxs': qux_holder_builder, - 'links': link_holder_builder}, - attributes={'namespace': CORE_NAMESPACE, - 'data_type': 'SimpleBucket', - 'object_id': bucket.object_id} + name="test_bucket", + groups={ + "foos": foo_holder_builder, + "quxs": qux_holder_builder, + "links": link_holder_builder, + }, + attributes={ + "namespace": CORE_NAMESPACE, + "data_type": "SimpleBucket", + "object_id": bucket.object_id, + }, ) return bucket, bucket_builder def get_one_bucket_test(self): - foos = [SimpleFoo('my_foo1')] - quxs = [SimpleQux('my_qux1', data=[1, 2, 3])] - bucket = SimpleBucket( - name='test_bucket', - foos=foos, - quxs=quxs, - links=foos - ) + foos = 
[SimpleFoo("my_foo1")] + quxs = [SimpleQux("my_qux1", data=[1, 2, 3])] + bucket = SimpleBucket(name="test_bucket", foos=foos, quxs=quxs, links=foos) foo1_builder = GroupBuilder( - name='my_foo1', - attributes={'namespace': CORE_NAMESPACE, - 'data_type': 'SimpleFoo', - 'object_id': bucket.foos['my_foo1'].object_id} - ) - foo_holder_builder = GroupBuilder( - name='foo_holder', - groups={'my_foo1': foo1_builder} - ) + name="my_foo1", + attributes={ + "namespace": CORE_NAMESPACE, + "data_type": "SimpleFoo", + "object_id": bucket.foos["my_foo1"].object_id, + }, + ) + foo_holder_builder = GroupBuilder(name="foo_holder", groups={"my_foo1": foo1_builder}) qux1_builder = DatasetBuilder( - name='my_qux1', + name="my_qux1", data=[1, 2, 3], - attributes={'namespace': CORE_NAMESPACE, - 'data_type': 'SimpleQux', - 'object_id': bucket.quxs['my_qux1'].object_id} - ) - qux_holder_builder = GroupBuilder( - name='qux_holder', - datasets={'my_qux1': qux1_builder} + attributes={ + "namespace": CORE_NAMESPACE, + "data_type": "SimpleQux", + "object_id": bucket.quxs["my_qux1"].object_id, + }, ) + qux_holder_builder = GroupBuilder(name="qux_holder", datasets={"my_qux1": qux1_builder}) foo1_link_builder = LinkBuilder(builder=foo1_builder) - link_holder_builder = GroupBuilder( - name='link_holder', - links={'my_foo1': foo1_link_builder} - ) + link_holder_builder = GroupBuilder(name="link_holder", links={"my_foo1": foo1_link_builder}) bucket_builder = GroupBuilder( - name='test_bucket', - groups={'foos': foo_holder_builder, - 'quxs': qux_holder_builder, - 'links': link_holder_builder}, - attributes={'namespace': CORE_NAMESPACE, - 'data_type': 'SimpleBucket', - 'object_id': bucket.object_id} + name="test_bucket", + groups={ + "foos": foo_holder_builder, + "quxs": qux_holder_builder, + "links": link_holder_builder, + }, + attributes={ + "namespace": CORE_NAMESPACE, + "data_type": "SimpleBucket", + "object_id": bucket.object_id, + }, ) return bucket, bucket_builder def get_zero_bucket_test(self): - bucket = SimpleBucket( - name='test_bucket' - ) - foo_holder_builder = GroupBuilder( - name='foo_holder', - groups={} - ) - qux_holder_builder = GroupBuilder( - name='qux_holder', - datasets={} - ) - link_holder_builder = GroupBuilder( - name='link_holder', - links={} - ) + bucket = SimpleBucket(name="test_bucket") + foo_holder_builder = GroupBuilder(name="foo_holder", groups={}) + qux_holder_builder = GroupBuilder(name="qux_holder", datasets={}) + link_holder_builder = GroupBuilder(name="link_holder", links={}) bucket_builder = GroupBuilder( - name='test_bucket', - groups={'foos': foo_holder_builder, - 'quxs': qux_holder_builder, - 'links': link_holder_builder}, - attributes={'namespace': CORE_NAMESPACE, - 'data_type': 'SimpleBucket', - 'object_id': bucket.object_id} + name="test_bucket", + groups={ + "foos": foo_holder_builder, + "quxs": qux_holder_builder, + "links": link_holder_builder, + }, + attributes={ + "namespace": CORE_NAMESPACE, + "data_type": "SimpleBucket", + "object_id": bucket.object_id, + }, ) return bucket, bucket_builder def get_mismatch_bucket_test(self): - foos = [NotSimpleFoo('my_foo1'), NotSimpleFoo('my_foo2')] - quxs = [NotSimpleQux('my_qux1', data=[1, 2, 3]), NotSimpleQux('my_qux2', data=[4, 5, 6])] - bucket = SimpleBucket( - name='test_bucket', - foos=foos, - quxs=quxs, - links=foos - ) - foo_holder_builder = GroupBuilder( - name='foo_holder', - groups={} - ) - qux_holder_builder = GroupBuilder( - name='qux_holder', - datasets={} - ) - link_holder_builder = GroupBuilder( - name='link_holder', - 
links={} - ) + foos = [NotSimpleFoo("my_foo1"), NotSimpleFoo("my_foo2")] + quxs = [ + NotSimpleQux("my_qux1", data=[1, 2, 3]), + NotSimpleQux("my_qux2", data=[4, 5, 6]), + ] + bucket = SimpleBucket(name="test_bucket", foos=foos, quxs=quxs, links=foos) + foo_holder_builder = GroupBuilder(name="foo_holder", groups={}) + qux_holder_builder = GroupBuilder(name="qux_holder", datasets={}) + link_holder_builder = GroupBuilder(name="link_holder", links={}) bucket_builder = GroupBuilder( - name='test_bucket', - groups={'foos': foo_holder_builder, - 'quxs': qux_holder_builder, - 'links': link_holder_builder}, - attributes={'namespace': CORE_NAMESPACE, - 'data_type': 'SimpleBucket', - 'object_id': bucket.object_id} + name="test_bucket", + groups={ + "foos": foo_holder_builder, + "quxs": qux_holder_builder, + "links": link_holder_builder, + }, + attributes={ + "namespace": CORE_NAMESPACE, + "data_type": "SimpleBucket", + "object_id": bucket.object_id, + }, ) return bucket, bucket_builder class TypeDefMixin: - def create_specs(self, quantity): # Type SimpleBucket contains: # - contains [quantity] groups of data_type_def SimpleFoo @@ -350,29 +363,29 @@ def create_specs(self, quantity): # NOTE: links do not have data_type_def, so leave them out of these tests # NOTE: nested type definitions are strongly discouraged now foo_spec = GroupSpec( - doc='the SimpleFoos in this bucket', - data_type_def='SimpleFoo', - quantity=quantity + doc="the SimpleFoos in this bucket", + data_type_def="SimpleFoo", + quantity=quantity, ) qux_spec = DatasetSpec( - doc='the SimpleQuxs in this bucket', - data_type_def='SimpleQux', - quantity=quantity + doc="the SimpleQuxs in this bucket", + data_type_def="SimpleQux", + quantity=quantity, ) not_foo_spec = GroupSpec( - doc='A test group specification with a data type', - data_type_def='NotSimpleFoo', + doc="A test group specification with a data type", + data_type_def="NotSimpleFoo", ) not_qux_spec = DatasetSpec( - doc='A test group specification with a data type', - data_type_def='NotSimpleQux', + doc="A test group specification with a data type", + data_type_def="NotSimpleQux", ) bucket_spec = GroupSpec( - doc='A test group specification for a data type containing data type', + doc="A test group specification for a data type containing data type", name="test_bucket", - data_type_def='SimpleBucket', + data_type_def="SimpleBucket", groups=[foo_spec], - datasets=[qux_spec] + datasets=[qux_spec], ) return [foo_spec, not_foo_spec, qux_spec, not_qux_spec, bucket_spec] @@ -380,129 +393,138 @@ def setUpBucketMapper(self): class BucketMapper(ObjectMapper): def __init__(self, spec): super().__init__(spec) - self.map_spec('foos', spec.get_data_type('SimpleFoo')) - self.map_spec('quxs', spec.get_data_type('SimpleQux')) + self.map_spec("foos", spec.get_data_type("SimpleFoo")) + self.map_spec("quxs", spec.get_data_type("SimpleQux")) return BucketMapper def get_two_bucket_test(self): - foos = [SimpleFoo('my_foo1'), SimpleFoo('my_foo2')] - quxs = [SimpleQux('my_qux1', data=[1, 2, 3]), SimpleQux('my_qux2', data=[4, 5, 6])] + foos = [SimpleFoo("my_foo1"), SimpleFoo("my_foo2")] + quxs = [ + SimpleQux("my_qux1", data=[1, 2, 3]), + SimpleQux("my_qux2", data=[4, 5, 6]), + ] bucket = SimpleBucket( - name='test_bucket', + name="test_bucket", foos=foos, quxs=quxs, ) - foo1_builder = self._create_builder(bucket.foos['my_foo1']) - foo2_builder = self._create_builder(bucket.foos['my_foo2']) - qux1_builder = self._create_builder(bucket.quxs['my_qux1']) - qux2_builder = 
self._create_builder(bucket.quxs['my_qux2']) + foo1_builder = self._create_builder(bucket.foos["my_foo1"]) + foo2_builder = self._create_builder(bucket.foos["my_foo2"]) + qux1_builder = self._create_builder(bucket.quxs["my_qux1"]) + qux2_builder = self._create_builder(bucket.quxs["my_qux2"]) bucket_builder = GroupBuilder( - name='test_bucket', - groups={'my_foo1': foo1_builder, - 'my_foo2': foo2_builder}, - datasets={'my_qux1': qux1_builder, - 'my_qux2': qux2_builder}, - attributes={'namespace': CORE_NAMESPACE, - 'data_type': 'SimpleBucket', - 'object_id': bucket.object_id} + name="test_bucket", + groups={"my_foo1": foo1_builder, "my_foo2": foo2_builder}, + datasets={"my_qux1": qux1_builder, "my_qux2": qux2_builder}, + attributes={ + "namespace": CORE_NAMESPACE, + "data_type": "SimpleBucket", + "object_id": bucket.object_id, + }, ) return bucket, bucket_builder def get_one_bucket_test(self): - foos = [SimpleFoo('my_foo1')] - quxs = [SimpleQux('my_qux1', data=[1, 2, 3])] + foos = [SimpleFoo("my_foo1")] + quxs = [SimpleQux("my_qux1", data=[1, 2, 3])] bucket = SimpleBucket( - name='test_bucket', + name="test_bucket", foos=foos, quxs=quxs, ) - foo1_builder = self._create_builder(bucket.foos['my_foo1']) - qux1_builder = self._create_builder(bucket.quxs['my_qux1']) + foo1_builder = self._create_builder(bucket.foos["my_foo1"]) + qux1_builder = self._create_builder(bucket.quxs["my_qux1"]) bucket_builder = GroupBuilder( - name='test_bucket', - groups={'my_foo1': foo1_builder}, - datasets={'my_qux1': qux1_builder}, - attributes={'namespace': CORE_NAMESPACE, - 'data_type': 'SimpleBucket', - 'object_id': bucket.object_id} + name="test_bucket", + groups={"my_foo1": foo1_builder}, + datasets={"my_qux1": qux1_builder}, + attributes={ + "namespace": CORE_NAMESPACE, + "data_type": "SimpleBucket", + "object_id": bucket.object_id, + }, ) return bucket, bucket_builder def get_zero_bucket_test(self): - bucket = SimpleBucket( - name='test_bucket' - ) + bucket = SimpleBucket(name="test_bucket") bucket_builder = GroupBuilder( - name='test_bucket', - attributes={'namespace': CORE_NAMESPACE, - 'data_type': 'SimpleBucket', - 'object_id': bucket.object_id} + name="test_bucket", + attributes={ + "namespace": CORE_NAMESPACE, + "data_type": "SimpleBucket", + "object_id": bucket.object_id, + }, ) return bucket, bucket_builder def get_mismatch_bucket_test(self): - foos = [NotSimpleFoo('my_foo1'), NotSimpleFoo('my_foo2')] - quxs = [NotSimpleQux('my_qux1', data=[1, 2, 3]), NotSimpleQux('my_qux2', data=[4, 5, 6])] + foos = [NotSimpleFoo("my_foo1"), NotSimpleFoo("my_foo2")] + quxs = [ + NotSimpleQux("my_qux1", data=[1, 2, 3]), + NotSimpleQux("my_qux2", data=[4, 5, 6]), + ] bucket = SimpleBucket( - name='test_bucket', + name="test_bucket", foos=foos, quxs=quxs, ) bucket_builder = GroupBuilder( - name='test_bucket', - attributes={'namespace': CORE_NAMESPACE, - 'data_type': 'SimpleBucket', - 'object_id': bucket.object_id} + name="test_bucket", + attributes={ + "namespace": CORE_NAMESPACE, + "data_type": "SimpleBucket", + "object_id": bucket.object_id, + }, ) return bucket, bucket_builder class TypeIncMixin: - def create_specs(self, quantity): # Type SimpleBucket contains: # - [quantity] groups of data_type_inc SimpleFoo # - [quantity] datasets of data_type_inc SimpleQux # - [quantity] links of target_type SimpleFoo foo_spec = GroupSpec( - doc='A test group specification with a data type', - data_type_def='SimpleFoo', + doc="A test group specification with a data type", + data_type_def="SimpleFoo", ) not_foo_spec = GroupSpec( - 
doc='A test group specification with a data type', - data_type_def='NotSimpleFoo', + doc="A test group specification with a data type", + data_type_def="NotSimpleFoo", ) qux_spec = DatasetSpec( - doc='A test group specification with a data type', - data_type_def='SimpleQux', + doc="A test group specification with a data type", + data_type_def="SimpleQux", ) not_qux_spec = DatasetSpec( - doc='A test group specification with a data type', - data_type_def='NotSimpleQux', + doc="A test group specification with a data type", + data_type_def="NotSimpleQux", ) foo_inc_spec = GroupSpec( - doc='the SimpleFoos in this bucket', - data_type_inc='SimpleFoo', - quantity=quantity + doc="the SimpleFoos in this bucket", + data_type_inc="SimpleFoo", + quantity=quantity, ) qux_inc_spec = DatasetSpec( - doc='the SimpleQuxs in this bucket', - data_type_inc='SimpleQux', - quantity=quantity + doc="the SimpleQuxs in this bucket", + data_type_inc="SimpleQux", + quantity=quantity, ) foo_link_spec = LinkSpec( - doc='the links in this bucket', - target_type='SimpleFoo', - quantity=quantity + doc="the links in this bucket", + target_type="SimpleFoo", + quantity=quantity, ) bucket_spec = GroupSpec( - doc='A test group specification for a data type containing data type', + doc="A test group specification for a data type containing data type", name="test_bucket", - data_type_def='SimpleBucket', + data_type_def="SimpleBucket", groups=[foo_inc_spec], datasets=[qux_inc_spec], - links=[foo_link_spec] + links=[foo_link_spec], ) return [foo_spec, not_foo_spec, qux_spec, not_qux_spec, bucket_spec] @@ -510,104 +532,97 @@ def setUpBucketMapper(self): class BucketMapper(ObjectMapper): def __init__(self, spec): super().__init__(spec) - self.map_spec('foos', spec.get_data_type('SimpleFoo')) - self.map_spec('quxs', spec.get_data_type('SimpleQux')) - self.map_spec('links', spec.links[0]) + self.map_spec("foos", spec.get_data_type("SimpleFoo")) + self.map_spec("quxs", spec.get_data_type("SimpleQux")) + self.map_spec("links", spec.links[0]) return BucketMapper def get_two_bucket_test(self): - foos = [SimpleFoo('my_foo1'), SimpleFoo('my_foo2')] - quxs = [SimpleQux('my_qux1', data=[1, 2, 3]), SimpleQux('my_qux2', data=[4, 5, 6])] + foos = [SimpleFoo("my_foo1"), SimpleFoo("my_foo2")] + quxs = [ + SimpleQux("my_qux1", data=[1, 2, 3]), + SimpleQux("my_qux2", data=[4, 5, 6]), + ] # NOTE: unlike in the other tests, links cannot map to the same foos in bucket because of a name clash - links = [SimpleFoo('my_foo3'), SimpleFoo('my_foo4')] - bucket = SimpleBucket( - name='test_bucket', - foos=foos, - quxs=quxs, - links=links - ) - foo1_builder = self._create_builder(bucket.foos['my_foo1']) - foo2_builder = self._create_builder(bucket.foos['my_foo2']) - foo3_builder = self._create_builder(bucket.links['my_foo3']) - foo4_builder = self._create_builder(bucket.links['my_foo4']) - qux1_builder = self._create_builder(bucket.quxs['my_qux1']) - qux2_builder = self._create_builder(bucket.quxs['my_qux2']) + links = [SimpleFoo("my_foo3"), SimpleFoo("my_foo4")] + bucket = SimpleBucket(name="test_bucket", foos=foos, quxs=quxs, links=links) + foo1_builder = self._create_builder(bucket.foos["my_foo1"]) + foo2_builder = self._create_builder(bucket.foos["my_foo2"]) + foo3_builder = self._create_builder(bucket.links["my_foo3"]) + foo4_builder = self._create_builder(bucket.links["my_foo4"]) + qux1_builder = self._create_builder(bucket.quxs["my_qux1"]) + qux2_builder = self._create_builder(bucket.quxs["my_qux2"]) foo3_link_builder = LinkBuilder(builder=foo3_builder) 
foo4_link_builder = LinkBuilder(builder=foo4_builder) bucket_builder = GroupBuilder( - name='test_bucket', - groups={'my_foo1': foo1_builder, - 'my_foo2': foo2_builder}, - datasets={'my_qux1': qux1_builder, - 'my_qux2': qux2_builder}, - links={'my_foo3': foo3_link_builder, - 'my_foo4': foo4_link_builder}, - attributes={'namespace': CORE_NAMESPACE, - 'data_type': 'SimpleBucket', - 'object_id': bucket.object_id} + name="test_bucket", + groups={"my_foo1": foo1_builder, "my_foo2": foo2_builder}, + datasets={"my_qux1": qux1_builder, "my_qux2": qux2_builder}, + links={"my_foo3": foo3_link_builder, "my_foo4": foo4_link_builder}, + attributes={ + "namespace": CORE_NAMESPACE, + "data_type": "SimpleBucket", + "object_id": bucket.object_id, + }, ) return bucket, bucket_builder def get_one_bucket_test(self): - foos = [SimpleFoo('my_foo1')] - quxs = [SimpleQux('my_qux1', data=[1, 2, 3])] + foos = [SimpleFoo("my_foo1")] + quxs = [SimpleQux("my_qux1", data=[1, 2, 3])] # NOTE: unlike in the other tests, links cannot map to the same foos in bucket because of a name clash - links = [SimpleFoo('my_foo3')] - bucket = SimpleBucket( - name='test_bucket', - foos=foos, - quxs=quxs, - links=links - ) - foo1_builder = self._create_builder(bucket.foos['my_foo1']) - foo3_builder = self._create_builder(bucket.links['my_foo3']) - qux1_builder = self._create_builder(bucket.quxs['my_qux1']) + links = [SimpleFoo("my_foo3")] + bucket = SimpleBucket(name="test_bucket", foos=foos, quxs=quxs, links=links) + foo1_builder = self._create_builder(bucket.foos["my_foo1"]) + foo3_builder = self._create_builder(bucket.links["my_foo3"]) + qux1_builder = self._create_builder(bucket.quxs["my_qux1"]) foo3_link_builder = LinkBuilder(builder=foo3_builder) bucket_builder = GroupBuilder( - name='test_bucket', - groups={'my_foo1': foo1_builder}, - datasets={'my_qux1': qux1_builder}, - links={'my_foo1': foo3_link_builder}, - attributes={'namespace': CORE_NAMESPACE, - 'data_type': 'SimpleBucket', - 'object_id': bucket.object_id} + name="test_bucket", + groups={"my_foo1": foo1_builder}, + datasets={"my_qux1": qux1_builder}, + links={"my_foo1": foo3_link_builder}, + attributes={ + "namespace": CORE_NAMESPACE, + "data_type": "SimpleBucket", + "object_id": bucket.object_id, + }, ) return bucket, bucket_builder def get_zero_bucket_test(self): - bucket = SimpleBucket( - name='test_bucket' - ) + bucket = SimpleBucket(name="test_bucket") bucket_builder = GroupBuilder( - name='test_bucket', - attributes={'namespace': CORE_NAMESPACE, - 'data_type': 'SimpleBucket', - 'object_id': bucket.object_id} + name="test_bucket", + attributes={ + "namespace": CORE_NAMESPACE, + "data_type": "SimpleBucket", + "object_id": bucket.object_id, + }, ) return bucket, bucket_builder def get_mismatch_bucket_test(self): - foos = [NotSimpleFoo('my_foo1'), NotSimpleFoo('my_foo2')] - quxs = [NotSimpleQux('my_qux1', data=[1, 2, 3]), NotSimpleQux('my_qux2', data=[4, 5, 6])] - links = [NotSimpleFoo('my_foo1'), NotSimpleFoo('my_foo2')] - bucket = SimpleBucket( - name='test_bucket', - foos=foos, - quxs=quxs, - links=links - ) + foos = [NotSimpleFoo("my_foo1"), NotSimpleFoo("my_foo2")] + quxs = [ + NotSimpleQux("my_qux1", data=[1, 2, 3]), + NotSimpleQux("my_qux2", data=[4, 5, 6]), + ] + links = [NotSimpleFoo("my_foo1"), NotSimpleFoo("my_foo2")] + bucket = SimpleBucket(name="test_bucket", foos=foos, quxs=quxs, links=links) bucket_builder = GroupBuilder( - name='test_bucket', - attributes={'namespace': CORE_NAMESPACE, - 'data_type': 'SimpleBucket', - 'object_id': bucket.object_id} + 
name="test_bucket", + attributes={ + "namespace": CORE_NAMESPACE, + "data_type": "SimpleBucket", + "object_id": bucket.object_id, + }, ) return bucket, bucket_builder class ZeroOrManyMixin: - def setUp(self): specs = self.create_specs(ZERO_OR_MANY) self.setUpManager(specs) @@ -638,7 +653,6 @@ def test_build_mismatch(self): class OneOrManyMixin: - def setUp(self): specs = self.create_specs(ONE_OR_MANY) self.setUpManager(specs) @@ -673,7 +687,6 @@ def test_build_mismatch(self): class OneMixin: - def setUp(self): specs = self.create_specs(DEF_QUANTITY) self.setUpManager(specs) @@ -681,7 +694,7 @@ def setUp(self): def test_build_two(self): """Test building a container which contains multiple containers as the spec allows.""" bucket, bucket_builder = self.get_two_bucket_test() - msg = r"SimpleBucket 'test_bucket' has 2 values for attribute '.*' but spec allows 1\." + msg = r"SimpleBucket 'test_bucket' has 2 values for attribute '.*' but spec" r" allows 1\." with self.assertWarnsRegex(IncorrectQuantityBuildWarning, msg): builder = self.manager.build(bucket) self.assertDictEqual(builder, bucket_builder) @@ -710,7 +723,6 @@ def test_build_mismatch(self): class TwoMixin: - def setUp(self): specs = self.create_specs(2) self.setUpManager(specs) @@ -724,7 +736,7 @@ def test_build_two(self): def test_build_one(self): """Test building a container which contains one container as the spec allows.""" bucket, bucket_builder = self.get_one_bucket_test() - msg = r"SimpleBucket 'test_bucket' has 1 values for attribute '.*' but spec allows 2\." + msg = r"SimpleBucket 'test_bucket' has 1 values for attribute '.*' but spec" r" allows 2\." with self.assertWarnsRegex(IncorrectQuantityBuildWarning, msg): builder = self.manager.build(bucket) self.assertDictEqual(builder, bucket_builder) @@ -747,7 +759,6 @@ def test_build_mismatch(self): class ZeroOrOneMixin: - def setUp(self): specs = self.create_specs(ZERO_OR_ONE) self.setUpManager(specs) @@ -755,7 +766,7 @@ def setUp(self): def test_build_two(self): """Test building a container which contains multiple containers as the spec allows.""" bucket, bucket_builder = self.get_two_bucket_test() - msg = r"SimpleBucket 'test_bucket' has 2 values for attribute '.*' but spec allows '\?'\." + msg = r"SimpleBucket 'test_bucket' has 2 values for attribute '.*' but spec" r" allows '\?'\." 
with self.assertWarnsRegex(IncorrectQuantityBuildWarning, msg): builder = self.manager.build(bucket) self.assertDictEqual(builder, bucket_builder) @@ -781,120 +792,123 @@ def test_build_mismatch(self): # Untyped group with included groups / included datasets / links with quantity {'*', '+', 1, 2 '?'} + class TestBuildZeroOrManyTypeIncUntypedGroup(ZeroOrManyMixin, TypeIncUntypedGroupMixin, BuildQuantityMixin, TestCase): - """Test building a group that has an untyped subgroup with a data type inc subgroup/dataset/link with quantity '*' - """ + """Test a group that has an untyped subgroup with a data type inc subgroup/dataset/link with quantity '*'""" + pass class TestBuildOneOrManyTypeIncUntypedGroup(OneOrManyMixin, TypeIncUntypedGroupMixin, BuildQuantityMixin, TestCase): - """Test building a group that has an untyped subgroup with a data type inc subgroup/dataset/link with quantity '+' - """ + """Test a group that has an untyped subgroup with a data type inc subgroup/dataset/link with quantity '+'""" + pass class TestBuildOneTypeIncUntypedGroup(OneMixin, TypeIncUntypedGroupMixin, BuildQuantityMixin, TestCase): - """Test building a group that has an untyped subgroup with a data type inc subgroup/dataset/link with quantity 1 - """ + """Test a group that has an untyped subgroup with a data type inc subgroup/dataset/link with quantity 1""" + pass class TestBuildTwoTypeIncUntypedGroup(TwoMixin, TypeIncUntypedGroupMixin, BuildQuantityMixin, TestCase): - """Test building a group that has an untyped subgroup with a data type inc subgroup/dataset/link with quantity 2 - """ + """Test a group that has an untyped subgroup with a data type inc subgroup/dataset/link with quantity 2""" + pass class TestBuildZeroOrOneTypeIncUntypedGroup(ZeroOrOneMixin, TypeIncUntypedGroupMixin, BuildQuantityMixin, TestCase): - """Test building a group that has an untyped subgroup with a data type inc subgroup/dataset/link with quantity '?' - """ + """Test a group that has an untyped subgroup with a data type inc subgroup/dataset/link with quantity '?'""" + pass # Nested type definition of group/dataset with quantity {'*', '+', 1, 2, '?'} + class TestBuildZeroOrManyTypeDef(ZeroOrManyMixin, TypeDefMixin, BuildQuantityMixin, TestCase): - """Test building a group that has a nested type def with quantity '*' - """ + """Test building a group that has a nested type def with quantity '*'""" + pass class TestBuildOneOrManyTypeDef(OneOrManyMixin, TypeDefMixin, BuildQuantityMixin, TestCase): - """Test building a group that has a nested type def with quantity '+' - """ + """Test building a group that has a nested type def with quantity '+'""" + pass class TestBuildOneTypeDef(OneMixin, TypeDefMixin, BuildQuantityMixin, TestCase): - """Test building a group that has a nested type def with quantity 1 - """ + """Test building a group that has a nested type def with quantity 1""" + pass class TestBuildTwoTypeDef(TwoMixin, TypeDefMixin, BuildQuantityMixin, TestCase): - """Test building a group that has a nested type def with quantity 2 - """ + """Test building a group that has a nested type def with quantity 2""" + pass class TestBuildZeroOrOneTypeDef(ZeroOrOneMixin, TypeDefMixin, BuildQuantityMixin, TestCase): - """Test building a group that has a nested type def with quantity '?' 
-    """
+    """Test building a group that has a nested type def with quantity '?'"""
+
     pass
 
 
 # Included groups / included datasets / links with quantity {'*', '+', 1, 2, '?'}
 
+
 class TestBuildZeroOrManyTypeInc(ZeroOrManyMixin, TypeIncMixin, BuildQuantityMixin, TestCase):
-    """Test building a group that has a data type inc subgroup/dataset/link with quantity '*'
-    """
+    """Test building a group that has a data type inc subgroup/dataset/link with quantity '*'"""
+
     pass
 
 
 class TestBuildOneOrManyTypeInc(OneOrManyMixin, TypeIncMixin, BuildQuantityMixin, TestCase):
-    """Test building a group that has a data type inc subgroup/dataset/link with quantity '+'
-    """
+    """Test building a group that has a data type inc subgroup/dataset/link with quantity '+'"""
+
     pass
 
 
 class TestBuildOneTypeInc(OneMixin, TypeIncMixin, BuildQuantityMixin, TestCase):
-    """Test building a group that has a data type inc subgroup/dataset/link with quantity 1
-    """
+    """Test building a group that has a data type inc subgroup/dataset/link with quantity 1"""
+
     pass
 
 
 class TestBuildTwoTypeInc(TwoMixin, TypeIncMixin, BuildQuantityMixin, TestCase):
-    """Test building a group that has a data type inc subgroup/dataset/link with quantity 2
-    """
+    """Test building a group that has a data type inc subgroup/dataset/link with quantity 2"""
+
     pass
 
 
 class TestBuildZeroOrOneTypeInc(ZeroOrOneMixin, TypeIncMixin, BuildQuantityMixin, TestCase):
-    """Test building a group that has a data type inc subgroup/dataset/link with quantity '?'
-    """
+    """Test building a group that has a data type inc subgroup/dataset/link with quantity '?'"""
+
     pass
 
 
 # Untyped group/dataset with quantity {1, '?'}
 
-class UntypedMixin:
 
+class UntypedMixin:
     def setUpManager(self, specs):
         spec_catalog = SpecCatalog()
-        schema_file = 'test.yaml'
+        schema_file = "test.yaml"
         for s in specs:
             spec_catalog.register_spec(s, schema_file)
         namespace = SpecNamespace(
-            doc='a test namespace',
+            doc="a test namespace",
             name=CORE_NAMESPACE,
-            schema=[{'source': schema_file}],
-            version='0.1.0',
-            catalog=spec_catalog
+            schema=[{"source": schema_file}],
+            version="0.1.0",
+            catalog=spec_catalog,
         )
         namespace_catalog = NamespaceCatalog()
         namespace_catalog.add_namespace(CORE_NAMESPACE, namespace)
         type_map = TypeMap(namespace_catalog)
-        type_map.register_container_type(CORE_NAMESPACE, 'BasicBucket', BasicBucket)
+        type_map.register_container_type(CORE_NAMESPACE, "BasicBucket", BasicBucket)
         self.manager = BuildManager(type_map)
 
     def create_specs(self, quantity):
@@ -904,28 +918,28 @@ def create_specs(self, quantity):
         # - [quantity] untyped array dataset
         # quantity can be only '?'
or 1 untyped_group_spec = GroupSpec( - doc='A test group specification with no data type', - name='untyped_group', + doc="A test group specification with no data type", + name="untyped_group", quantity=quantity, ) untyped_dataset_spec = DatasetSpec( - doc='A test dataset specification with no data type', - name='untyped_dataset', - dtype='int', + doc="A test dataset specification with no data type", + name="untyped_dataset", + dtype="int", quantity=quantity, ) untyped_array_dataset_spec = DatasetSpec( - doc='A test dataset specification with no data type', - name='untyped_array_dataset', - dtype='int', + doc="A test dataset specification with no data type", + name="untyped_array_dataset", + dtype="int", dims=[None], shape=[None], quantity=quantity, ) basic_bucket_spec = GroupSpec( - doc='A test group specification for a data type containing data type', + doc="A test group specification for a data type containing data type", name="test_bucket", - data_type_def='BasicBucket', + data_type_def="BasicBucket", groups=[untyped_group_spec], datasets=[untyped_dataset_spec, untyped_array_dataset_spec], ) @@ -933,43 +947,49 @@ def create_specs(self, quantity): class TestBuildOneUntyped(UntypedMixin, TestCase): - """Test building a group that has an untyped subgroup/dataset with quantity 1. - """ + """Test building a group that has an untyped subgroup/dataset with quantity 1.""" + def setUp(self): specs = self.create_specs(DEF_QUANTITY) self.setUpManager(specs) def test_build_data(self): """Test building a container which contains an untyped empty subgroup and an untyped non-empty dataset.""" - bucket = BasicBucket(name='test_bucket', untyped_dataset=3, untyped_array_dataset=[3]) + bucket = BasicBucket(name="test_bucket", untyped_dataset=3, untyped_array_dataset=[3]) # a required untyped empty group builder will be created by default - untyped_group_builder = GroupBuilder(name='untyped_group') - untyped_dataset_builder = DatasetBuilder(name='untyped_dataset', data=3) - untyped_array_dataset_builder = DatasetBuilder(name='untyped_array_dataset', data=[3]) + untyped_group_builder = GroupBuilder(name="untyped_group") + untyped_dataset_builder = DatasetBuilder(name="untyped_dataset", data=3) + untyped_array_dataset_builder = DatasetBuilder(name="untyped_array_dataset", data=[3]) bucket_builder = GroupBuilder( - name='test_bucket', - groups={'untyped_group': untyped_group_builder}, - datasets={'untyped_dataset': untyped_dataset_builder, - 'untyped_array_dataset': untyped_array_dataset_builder}, - attributes={'namespace': CORE_NAMESPACE, - 'data_type': 'BasicBucket', - 'object_id': bucket.object_id} + name="test_bucket", + groups={"untyped_group": untyped_group_builder}, + datasets={ + "untyped_dataset": untyped_dataset_builder, + "untyped_array_dataset": untyped_array_dataset_builder, + }, + attributes={ + "namespace": CORE_NAMESPACE, + "data_type": "BasicBucket", + "object_id": bucket.object_id, + }, ) builder = self.manager.build(bucket) self.assertDictEqual(builder, bucket_builder) def test_build_empty_data(self): """Test building a container which contains an untyped empty subgroup and an untyped empty dataset.""" - bucket = BasicBucket(name='test_bucket') + bucket = BasicBucket(name="test_bucket") # a required untyped empty group builder will be created by default - untyped_group_builder = GroupBuilder(name='untyped_group') + untyped_group_builder = GroupBuilder(name="untyped_group") # a required untyped empty dataset builder will NOT be created by default bucket_builder = GroupBuilder( - 
name='test_bucket', - groups={'untyped_group': untyped_group_builder}, - attributes={'namespace': CORE_NAMESPACE, - 'data_type': 'BasicBucket', - 'object_id': bucket.object_id} + name="test_bucket", + groups={"untyped_group": untyped_group_builder}, + attributes={ + "namespace": CORE_NAMESPACE, + "data_type": "BasicBucket", + "object_id": bucket.object_id, + }, ) msg = "BasicBucket 'test_bucket' is missing required value for attribute 'untyped_dataset'." # also raises "BasicBucket 'test_bucket' is missing required value for attribute 'untyped_array_dataset'." @@ -979,39 +999,45 @@ def test_build_empty_data(self): class TestBuildZeroOrOneUntyped(UntypedMixin, TestCase): - """Test building a group that has an untyped subgroup/dataset with quantity '?'. - """ + """Test building a group that has an untyped subgroup/dataset with quantity '?'.""" + def setUp(self): specs = self.create_specs(ZERO_OR_ONE) self.setUpManager(specs) def test_build_data(self): """Test building a container which contains an untyped empty subgroup and an untyped non-empty dataset.""" - bucket = BasicBucket(name='test_bucket', untyped_dataset=3, untyped_array_dataset=[3]) + bucket = BasicBucket(name="test_bucket", untyped_dataset=3, untyped_array_dataset=[3]) # an optional untyped empty group builder will NOT be created by default - untyped_dataset_builder = DatasetBuilder(name='untyped_dataset', data=3) - untyped_array_dataset_builder = DatasetBuilder(name='untyped_array_dataset', data=[3]) + untyped_dataset_builder = DatasetBuilder(name="untyped_dataset", data=3) + untyped_array_dataset_builder = DatasetBuilder(name="untyped_array_dataset", data=[3]) bucket_builder = GroupBuilder( - name='test_bucket', - datasets={'untyped_dataset': untyped_dataset_builder, - 'untyped_array_dataset': untyped_array_dataset_builder}, - attributes={'namespace': CORE_NAMESPACE, - 'data_type': 'BasicBucket', - 'object_id': bucket.object_id} + name="test_bucket", + datasets={ + "untyped_dataset": untyped_dataset_builder, + "untyped_array_dataset": untyped_array_dataset_builder, + }, + attributes={ + "namespace": CORE_NAMESPACE, + "data_type": "BasicBucket", + "object_id": bucket.object_id, + }, ) builder = self.manager.build(bucket) self.assertDictEqual(builder, bucket_builder) def test_build_empty_data(self): """Test building a container which contains an untyped empty subgroup and an untyped empty dataset.""" - bucket = BasicBucket(name='test_bucket') + bucket = BasicBucket(name="test_bucket") # an optional untyped empty group builder will NOT be created by default # an optional untyped empty dataset builder will NOT be created by default bucket_builder = GroupBuilder( - name='test_bucket', - attributes={'namespace': CORE_NAMESPACE, - 'data_type': 'BasicBucket', - 'object_id': bucket.object_id} + name="test_bucket", + attributes={ + "namespace": CORE_NAMESPACE, + "data_type": "BasicBucket", + "object_id": bucket.object_id, + }, ) builder = self.manager.build(bucket) self.assertDictEqual(builder, bucket_builder) @@ -1019,9 +1045,9 @@ def test_build_empty_data(self): # Multiple allowed types + class TestBuildMultiTypeInc(BuildQuantityMixin, TestCase): - """Test build process when a groupspec allows multiple groups/datasets/links with different data types / targets. 
- """ + """Test build process when a groupspec allows multiple groups/datasets/links with different data types / targets.""" def setUp(self): specs = self.create_specs(ZERO_OR_MANY) @@ -1033,58 +1059,58 @@ def create_specs(self, quantity): # - [quantity] datasets of data_type_inc SimpleQux and [quantity] datasets of data_type_inc NotSimpleQux # - [quantity] links of target_type SimpleFoo and [quantity] links of target_type NotSimpleFoo foo_spec = GroupSpec( - doc='A test group specification with a data type', - data_type_def='SimpleFoo', + doc="A test group specification with a data type", + data_type_def="SimpleFoo", ) not_foo_spec = GroupSpec( - doc='A test group specification with a data type', - data_type_def='NotSimpleFoo', + doc="A test group specification with a data type", + data_type_def="NotSimpleFoo", ) qux_spec = DatasetSpec( - doc='A test group specification with a data type', - data_type_def='SimpleQux', + doc="A test group specification with a data type", + data_type_def="SimpleQux", ) not_qux_spec = DatasetSpec( - doc='A test group specification with a data type', - data_type_def='NotSimpleQux', + doc="A test group specification with a data type", + data_type_def="NotSimpleQux", ) foo_inc_spec = GroupSpec( - doc='the SimpleFoos in this bucket', - data_type_inc='SimpleFoo', - quantity=quantity + doc="the SimpleFoos in this bucket", + data_type_inc="SimpleFoo", + quantity=quantity, ) not_foo_inc_spec = GroupSpec( - doc='the SimpleFoos in this bucket', - data_type_inc='NotSimpleFoo', - quantity=quantity + doc="the SimpleFoos in this bucket", + data_type_inc="NotSimpleFoo", + quantity=quantity, ) qux_inc_spec = DatasetSpec( - doc='the SimpleQuxs in this bucket', - data_type_inc='SimpleQux', - quantity=quantity + doc="the SimpleQuxs in this bucket", + data_type_inc="SimpleQux", + quantity=quantity, ) not_qux_inc_spec = DatasetSpec( - doc='the SimpleQuxs in this bucket', - data_type_inc='NotSimpleQux', - quantity=quantity + doc="the SimpleQuxs in this bucket", + data_type_inc="NotSimpleQux", + quantity=quantity, ) foo_link_spec = LinkSpec( - doc='the links in this bucket', - target_type='SimpleFoo', - quantity=quantity + doc="the links in this bucket", + target_type="SimpleFoo", + quantity=quantity, ) not_foo_link_spec = LinkSpec( - doc='the links in this bucket', - target_type='NotSimpleFoo', - quantity=quantity + doc="the links in this bucket", + target_type="NotSimpleFoo", + quantity=quantity, ) bucket_spec = GroupSpec( - doc='A test group specification for a data type containing data type', + doc="A test group specification for a data type containing data type", name="test_bucket", - data_type_def='SimpleBucket', + data_type_def="SimpleBucket", groups=[foo_inc_spec, not_foo_inc_spec], datasets=[qux_inc_spec, not_qux_inc_spec], - links=[foo_link_spec, not_foo_link_spec] + links=[foo_link_spec, not_foo_link_spec], ) return [foo_spec, not_foo_spec, qux_spec, not_qux_spec, bucket_spec] @@ -1092,45 +1118,42 @@ def setUpBucketMapper(self): class BucketMapper(ObjectMapper): def __init__(self, spec): super().__init__(spec) - self.map_spec('foos', spec.get_data_type('SimpleFoo')) - self.map_spec('foos', spec.get_data_type('NotSimpleFoo')) - self.map_spec('quxs', spec.get_data_type('SimpleQux')) - self.map_spec('quxs', spec.get_data_type('NotSimpleQux')) - self.map_spec('links', spec.links[0]) - self.map_spec('links', spec.links[1]) + self.map_spec("foos", spec.get_data_type("SimpleFoo")) + self.map_spec("foos", spec.get_data_type("NotSimpleFoo")) + self.map_spec("quxs", 
spec.get_data_type("SimpleQux")) + self.map_spec("quxs", spec.get_data_type("NotSimpleQux")) + self.map_spec("links", spec.links[0]) + self.map_spec("links", spec.links[1]) return BucketMapper def get_two_bucket_test(self): - foos = [SimpleFoo('my_foo1'), NotSimpleFoo('my_foo2')] - quxs = [SimpleQux('my_qux1', data=[1, 2, 3]), NotSimpleQux('my_qux2', data=[4, 5, 6])] + foos = [SimpleFoo("my_foo1"), NotSimpleFoo("my_foo2")] + quxs = [ + SimpleQux("my_qux1", data=[1, 2, 3]), + NotSimpleQux("my_qux2", data=[4, 5, 6]), + ] # NOTE: unlike in the other tests, links cannot map to the same foos in bucket because of a name clash - links = [SimpleFoo('my_foo3'), NotSimpleFoo('my_foo4')] - bucket = SimpleBucket( - name='test_bucket', - foos=foos, - quxs=quxs, - links=links - ) - foo1_builder = self._create_builder(bucket.foos['my_foo1']) - foo2_builder = self._create_builder(bucket.foos['my_foo2']) - foo3_builder = self._create_builder(bucket.links['my_foo3']) - foo4_builder = self._create_builder(bucket.links['my_foo4']) - qux1_builder = self._create_builder(bucket.quxs['my_qux1']) - qux2_builder = self._create_builder(bucket.quxs['my_qux2']) + links = [SimpleFoo("my_foo3"), NotSimpleFoo("my_foo4")] + bucket = SimpleBucket(name="test_bucket", foos=foos, quxs=quxs, links=links) + foo1_builder = self._create_builder(bucket.foos["my_foo1"]) + foo2_builder = self._create_builder(bucket.foos["my_foo2"]) + foo3_builder = self._create_builder(bucket.links["my_foo3"]) + foo4_builder = self._create_builder(bucket.links["my_foo4"]) + qux1_builder = self._create_builder(bucket.quxs["my_qux1"]) + qux2_builder = self._create_builder(bucket.quxs["my_qux2"]) foo3_link_builder = LinkBuilder(builder=foo3_builder) foo4_link_builder = LinkBuilder(builder=foo4_builder) bucket_builder = GroupBuilder( - name='test_bucket', - groups={'my_foo1': foo1_builder, - 'my_foo2': foo2_builder}, - datasets={'my_qux1': qux1_builder, - 'my_qux2': qux2_builder}, - links={'my_foo3': foo3_link_builder, - 'my_foo4': foo4_link_builder}, - attributes={'namespace': CORE_NAMESPACE, - 'data_type': 'SimpleBucket', - 'object_id': bucket.object_id} + name="test_bucket", + groups={"my_foo1": foo1_builder, "my_foo2": foo2_builder}, + datasets={"my_qux1": qux1_builder, "my_qux2": qux2_builder}, + links={"my_foo3": foo3_link_builder, "my_foo4": foo4_link_builder}, + attributes={ + "namespace": CORE_NAMESPACE, + "data_type": "SimpleBucket", + "object_id": bucket.object_id, + }, ) return bucket, bucket_builder diff --git a/tests/unit/build_tests/test_builder.py b/tests/unit/build_tests/test_builder.py index a35dc64ac..44e0b9569 100644 --- a/tests/unit/build_tests/test_builder.py +++ b/tests/unit/build_tests/test_builder.py @@ -1,108 +1,112 @@ -from hdmf.build import GroupBuilder, DatasetBuilder, LinkBuilder, ReferenceBuilder, RegionBuilder +from hdmf.build import ( + DatasetBuilder, + GroupBuilder, + LinkBuilder, + ReferenceBuilder, + RegionBuilder, +) from hdmf.testing import TestCase class TestGroupBuilder(TestCase): - def test_constructor(self): - gb1 = GroupBuilder('gb1', source='source') - gb2 = GroupBuilder('gb2', parent=gb1) - self.assertIs(gb1.name, 'gb1') + gb1 = GroupBuilder("gb1", source="source") + gb2 = GroupBuilder("gb2", parent=gb1) + self.assertIs(gb1.name, "gb1") self.assertIsNone(gb1.parent) - self.assertEqual(gb1.source, 'source') + self.assertEqual(gb1.source, "source") self.assertIs(gb2.parent, gb1) def test_repr(self): - gb1 = GroupBuilder('gb1') + gb1 = GroupBuilder("gb1") expected = "gb1 GroupBuilder {'attributes': {}, 
'groups': {}, 'datasets': {}, 'links': {}}" self.assertEqual(gb1.__repr__(), expected) def test_set_source(self): """Test that setting source sets the children builder source.""" - gb1 = GroupBuilder('gb1') - db = DatasetBuilder('db', list(range(10))) - lb = LinkBuilder(gb1, 'lb') - gb2 = GroupBuilder('gb1', {'gb1': gb1}, {'db': db}, {}, {'lb': lb}) - gb2.source = 'source' - self.assertEqual(gb2.source, 'source') - self.assertEqual(gb1.source, 'source') - self.assertEqual(db.source, 'source') - self.assertEqual(lb.source, 'source') + gb1 = GroupBuilder("gb1") + db = DatasetBuilder("db", list(range(10))) + lb = LinkBuilder(gb1, "lb") + gb2 = GroupBuilder("gb1", {"gb1": gb1}, {"db": db}, {}, {"lb": lb}) + gb2.source = "source" + self.assertEqual(gb2.source, "source") + self.assertEqual(gb1.source, "source") + self.assertEqual(db.source, "source") + self.assertEqual(lb.source, "source") def test_set_source_no_reset(self): """Test that setting source does not set the children builder source if children already have a source.""" - gb1 = GroupBuilder('gb1', source='original') - db = DatasetBuilder('db', list(range(10)), source='original') - lb = LinkBuilder(gb1, 'lb', source='original') - gb2 = GroupBuilder('gb1', {'gb1': gb1}, {'db': db}, {}, {'lb': lb}) - gb2.source = 'source' - self.assertEqual(gb1.source, 'original') - self.assertEqual(db.source, 'original') - self.assertEqual(lb.source, 'original') + gb1 = GroupBuilder("gb1", source="original") + db = DatasetBuilder("db", list(range(10)), source="original") + lb = LinkBuilder(gb1, "lb", source="original") + gb2 = GroupBuilder("gb1", {"gb1": gb1}, {"db": db}, {}, {"lb": lb}) + gb2.source = "source" + self.assertEqual(gb1.source, "original") + self.assertEqual(db.source, "original") + self.assertEqual(lb.source, "original") def test_constructor_dset_none(self): - gb1 = GroupBuilder('gb1', datasets={'empty': None}) + gb1 = GroupBuilder("gb1", datasets={"empty": None}) self.assertEqual(len(gb1.datasets), 0) def test_set_location(self): - gb1 = GroupBuilder('gb1') - gb1.location = 'location' - self.assertEqual(gb1.location, 'location') + gb1 = GroupBuilder("gb1") + gb1.location = "location" + self.assertEqual(gb1.location, "location") def test_overwrite_location(self): - gb1 = GroupBuilder('gb1') - gb1.location = 'location' - gb1.location = 'new location' - self.assertEqual(gb1.location, 'new location') + gb1 = GroupBuilder("gb1") + gb1.location = "location" + gb1.location = "new location" + self.assertEqual(gb1.location, "new location") class TestGroupBuilderSetters(TestCase): - def test_set_attribute(self): - gb = GroupBuilder('gb') - gb.set_attribute('key', 'value') - self.assertIn('key', gb.obj_type) - self.assertIn('key', gb.attributes) - self.assertEqual(gb['key'], 'value') + gb = GroupBuilder("gb") + gb.set_attribute("key", "value") + self.assertIn("key", gb.obj_type) + self.assertIn("key", gb.attributes) + self.assertEqual(gb["key"], "value") def test_set_group(self): - gb1 = GroupBuilder('gb1') - gb2 = GroupBuilder('gb2') + gb1 = GroupBuilder("gb1") + gb2 = GroupBuilder("gb2") gb1.set_group(gb2) self.assertIs(gb2.parent, gb1) - self.assertIn('gb2', gb1.obj_type) - self.assertIn('gb2', gb1.groups) - self.assertIs(gb1['gb2'], gb2) + self.assertIn("gb2", gb1.obj_type) + self.assertIn("gb2", gb1.groups) + self.assertIs(gb1["gb2"], gb2) def test_set_dataset(self): - gb = GroupBuilder('gb') - db = DatasetBuilder('db', list(range(10))) + gb = GroupBuilder("gb") + db = DatasetBuilder("db", list(range(10))) gb.set_dataset(db) 
self.assertIs(db.parent, gb) - self.assertIn('db', gb.obj_type) - self.assertIn('db', gb.datasets) - self.assertIs(gb['db'], db) + self.assertIn("db", gb.obj_type) + self.assertIn("db", gb.datasets) + self.assertIs(gb["db"], db) def test_set_link(self): - gb1 = GroupBuilder('gb1') - gb2 = GroupBuilder('gb2') + gb1 = GroupBuilder("gb1") + gb2 = GroupBuilder("gb2") lb = LinkBuilder(gb2) gb1.set_link(lb) self.assertIs(lb.parent, gb1) - self.assertIn('gb2', gb1.obj_type) - self.assertIn('gb2', gb1.links) - self.assertIs(gb1['gb2'], lb) + self.assertIn("gb2", gb1.obj_type) + self.assertIn("gb2", gb1.links) + self.assertIs(gb1["gb2"], lb) def test_setitem_disabled(self): """Test __setitem__ is disabled""" - gb = GroupBuilder('gb') + gb = GroupBuilder("gb") with self.assertRaises(NotImplementedError): - gb['key'] = 'value' + gb["key"] = "value" def test_set_exists_wrong_type(self): - gb1 = GroupBuilder('gb1') - gb2 = GroupBuilder('gb2') - db = DatasetBuilder('gb2') + gb1 = GroupBuilder("gb1") + gb2 = GroupBuilder("gb2") + db = DatasetBuilder("gb2") gb1.set_group(gb2) msg = "'gb2' already exists in gb1.groups, cannot set in datasets." with self.assertRaisesWith(ValueError, msg): @@ -110,95 +114,93 @@ def test_set_exists_wrong_type(self): class TestGroupBuilderGetters(TestCase): - def setUp(self): - self.subgroup1 = GroupBuilder('subgroup1') - self.dataset1 = DatasetBuilder('dataset1', list(range(10))) - self.link1 = LinkBuilder(self.subgroup1, 'link1') + self.subgroup1 = GroupBuilder("subgroup1") + self.dataset1 = DatasetBuilder("dataset1", list(range(10))) + self.link1 = LinkBuilder(self.subgroup1, "link1") self.int_attr = 1 self.str_attr = "my_str" - self.group1 = GroupBuilder('group1', {'subgroup1': self.subgroup1}) + self.group1 = GroupBuilder("group1", {"subgroup1": self.subgroup1}) self.gb = GroupBuilder( - name='gb', - groups={'group1': self.group1}, - datasets={'dataset1': self.dataset1}, - attributes={'int_attr': self.int_attr, - 'str_attr': self.str_attr}, - links={'link1': self.link1} + name="gb", + groups={"group1": self.group1}, + datasets={"dataset1": self.dataset1}, + attributes={"int_attr": self.int_attr, "str_attr": self.str_attr}, + links={"link1": self.link1}, ) def test_path(self): - self.assertEqual(self.subgroup1.path, 'gb/group1/subgroup1') - self.assertEqual(self.dataset1.path, 'gb/dataset1') - self.assertEqual(self.link1.path, 'gb/link1') - self.assertEqual(self.group1.path, 'gb/group1') - self.assertEqual(self.gb.path, 'gb') + self.assertEqual(self.subgroup1.path, "gb/group1/subgroup1") + self.assertEqual(self.dataset1.path, "gb/dataset1") + self.assertEqual(self.link1.path, "gb/link1") + self.assertEqual(self.group1.path, "gb/group1") + self.assertEqual(self.gb.path, "gb") def test_getitem_group(self): """Test __getitem__ for groups""" - self.assertIs(self.gb['group1'], self.group1) + self.assertIs(self.gb["group1"], self.group1) def test_getitem_group_deeper(self): """Test __getitem__ for groups deeper in hierarchy""" - self.assertIs(self.gb['group1/subgroup1'], self.subgroup1) + self.assertIs(self.gb["group1/subgroup1"], self.subgroup1) def test_getitem_dataset(self): """Test __getitem__ for datasets""" - self.assertIs(self.gb['dataset1'], self.dataset1) + self.assertIs(self.gb["dataset1"], self.dataset1) def test_getitem_attr(self): """Test __getitem__ for attributes""" - self.assertEqual(self.gb['int_attr'], self.int_attr) - self.assertEqual(self.gb['str_attr'], self.str_attr) + self.assertEqual(self.gb["int_attr"], self.int_attr) + 
self.assertEqual(self.gb["str_attr"], self.str_attr) def test_getitem_invalid_key(self): """Test __getitem__ for invalid key""" with self.assertRaises(KeyError): - self.gb['invalid_key'] + self.gb["invalid_key"] def test_getitem_invalid_key_deeper(self): """Test __getitem__ for invalid key""" with self.assertRaises(KeyError): - self.gb['group/invalid_key'] + self.gb["group/invalid_key"] def test_getitem_link(self): """Test __getitem__ for links""" - self.assertIs(self.gb['link1'], self.link1) + self.assertIs(self.gb["link1"], self.link1) def test_get_group(self): """Test get() for groups""" - self.assertIs(self.gb.get('group1'), self.group1) + self.assertIs(self.gb.get("group1"), self.group1) def test_get_group_deeper(self): """Test get() for groups deeper in hierarchy""" - self.assertIs(self.gb.get('group1/subgroup1'), self.subgroup1) + self.assertIs(self.gb.get("group1/subgroup1"), self.subgroup1) def test_get_dataset(self): """Test get() for datasets""" - self.assertIs(self.gb.get('dataset1'), self.dataset1) + self.assertIs(self.gb.get("dataset1"), self.dataset1) def test_get_attr(self): """Test get() for attributes""" - self.assertEqual(self.gb.get('int_attr'), self.int_attr) - self.assertEqual(self.gb.get('str_attr'), self.str_attr) + self.assertEqual(self.gb.get("int_attr"), self.int_attr) + self.assertEqual(self.gb.get("str_attr"), self.str_attr) def test_get_link(self): """Test get() for links""" - self.assertIs(self.gb.get('link1'), self.link1) + self.assertIs(self.gb.get("link1"), self.link1) def test_get_invalid_key(self): """Test get() for invalid key""" - self.assertIs(self.gb.get('invalid_key'), None) + self.assertIs(self.gb.get("invalid_key"), None) def test_items(self): """Test items()""" items = ( - ('group1', self.group1), - ('dataset1', self.dataset1), - ('int_attr', self.int_attr), - ('str_attr', self.str_attr), - ('link1', self.link1), + ("group1", self.group1), + ("dataset1", self.dataset1), + ("int_attr", self.int_attr), + ("str_attr", self.str_attr), + ("link1", self.link1), ) # self.assertSetEqual(items, set(self.gb.items())) try: @@ -209,11 +211,11 @@ def test_items(self): def test_keys(self): """Test keys()""" keys = ( - 'group1', - 'dataset1', - 'int_attr', - 'str_attr', - 'link1', + "group1", + "dataset1", + "int_attr", + "str_attr", + "link1", ) try: self.assertCountEqual(keys, self.gb.keys()) @@ -236,168 +238,163 @@ def test_values(self): class TestGroupBuilderIsEmpty(TestCase): - def test_is_empty_true(self): """Test empty when group has nothing in it""" - gb = GroupBuilder('gb') + gb = GroupBuilder("gb") self.assertTrue(gb.is_empty()) def test_is_empty_true_group_empty(self): """Test is_empty() when group has an empty subgroup""" - gb1 = GroupBuilder('my_subgroup') - gb2 = GroupBuilder('gb', {'my_subgroup': gb1}) + gb1 = GroupBuilder("my_subgroup") + gb2 = GroupBuilder("gb", {"my_subgroup": gb1}) self.assertTrue(gb2.is_empty()) def test_is_empty_false_dataset(self): """Test is_empty() when group has a dataset""" - gb = GroupBuilder('gb', datasets={'my_dataset': DatasetBuilder('my_dataset')}) + gb = GroupBuilder("gb", datasets={"my_dataset": DatasetBuilder("my_dataset")}) self.assertFalse(gb.is_empty()) def test_is_empty_false_group_dataset(self): """Test is_empty() when group has a subgroup with a dataset""" - gb1 = GroupBuilder('my_subgroup', datasets={'my_dataset': DatasetBuilder('my_dataset')}) - gb2 = GroupBuilder('gb', {'my_subgroup': gb1}) + gb1 = GroupBuilder("my_subgroup", datasets={"my_dataset": DatasetBuilder("my_dataset")}) + gb2 = 
GroupBuilder("gb", {"my_subgroup": gb1}) self.assertFalse(gb2.is_empty()) def test_is_empty_false_attribute(self): """Test is_empty() when group has an attribute""" - gb = GroupBuilder('gb', attributes={'my_attr': 'attr_value'}) + gb = GroupBuilder("gb", attributes={"my_attr": "attr_value"}) self.assertFalse(gb.is_empty()) def test_is_empty_false_group_attribute(self): """Test is_empty() when group has subgroup with an attribute""" - gb1 = GroupBuilder('my_subgroup', attributes={'my_attr': 'attr_value'}) - gb2 = GroupBuilder('gb', {'my_subgroup': gb1}) + gb1 = GroupBuilder("my_subgroup", attributes={"my_attr": "attr_value"}) + gb2 = GroupBuilder("gb", {"my_subgroup": gb1}) self.assertFalse(gb2.is_empty()) def test_is_empty_false_link(self): """Test is_empty() when group has a link""" - gb1 = GroupBuilder('target') - gb2 = GroupBuilder('gb', links={'my_link': LinkBuilder(gb1)}) + gb1 = GroupBuilder("target") + gb2 = GroupBuilder("gb", links={"my_link": LinkBuilder(gb1)}) self.assertFalse(gb2.is_empty()) def test_is_empty_false_group_link(self): """Test is_empty() when group has subgroup with a link""" - gb1 = GroupBuilder('target') - gb2 = GroupBuilder('my_subgroup', links={'my_link': LinkBuilder(gb1)}) - gb3 = GroupBuilder('gb', {'my_subgroup': gb2}) + gb1 = GroupBuilder("target") + gb2 = GroupBuilder("my_subgroup", links={"my_link": LinkBuilder(gb1)}) + gb3 = GroupBuilder("gb", {"my_subgroup": gb2}) self.assertFalse(gb3.is_empty()) class TestDatasetBuilder(TestCase): - def test_constructor(self): - gb1 = GroupBuilder('gb1') + gb1 = GroupBuilder("gb1") db1 = DatasetBuilder( - name='db1', + name="db1", data=[1, 2, 3], dtype=int, - attributes={'attr1': 10}, + attributes={"attr1": 10}, maxshape=10, chunks=True, parent=gb1, - source='source', + source="source", ) - self.assertEqual(db1.name, 'db1') + self.assertEqual(db1.name, "db1") self.assertListEqual(db1.data, [1, 2, 3]) self.assertEqual(db1.dtype, int) - self.assertDictEqual(db1.attributes, {'attr1': 10}) + self.assertDictEqual(db1.attributes, {"attr1": 10}) self.assertEqual(db1.maxshape, 10) self.assertTrue(db1.chunks) self.assertIs(db1.parent, gb1) - self.assertEqual(db1.source, 'source') + self.assertEqual(db1.source, "source") def test_constructor_data_builder_no_dtype(self): - db1 = DatasetBuilder(name='db1', dtype=int) - db2 = DatasetBuilder(name='db2', data=db1) + db1 = DatasetBuilder(name="db1", dtype=int) + db2 = DatasetBuilder(name="db2", data=db1) self.assertEqual(db2.dtype, DatasetBuilder.OBJECT_REF_TYPE) def test_constructor_data_builder_dtype(self): - db1 = DatasetBuilder(name='db1', dtype=int) - db2 = DatasetBuilder(name='db2', data=db1, dtype=float) + db1 = DatasetBuilder(name="db1", dtype=int) + db2 = DatasetBuilder(name="db2", data=db1, dtype=float) self.assertEqual(db2.dtype, float) def test_set_data(self): - db1 = DatasetBuilder(name='db1') + db1 = DatasetBuilder(name="db1") db1.data = [4, 5, 6] self.assertEqual(db1.data, [4, 5, 6]) def test_set_dtype(self): - db1 = DatasetBuilder(name='db1') + db1 = DatasetBuilder(name="db1") db1.dtype = float self.assertEqual(db1.dtype, float) def test_overwrite_data(self): - db1 = DatasetBuilder(name='db1', data=[1, 2, 3]) + db1 = DatasetBuilder(name="db1", data=[1, 2, 3]) msg = "Cannot overwrite data." with self.assertRaisesWith(AttributeError, msg): db1.data = [4, 5, 6] def test_overwrite_dtype(self): - db1 = DatasetBuilder(name='db1', data=[1, 2, 3], dtype=int) + db1 = DatasetBuilder(name="db1", data=[1, 2, 3], dtype=int) msg = "Cannot overwrite dtype." 
with self.assertRaisesWith(AttributeError, msg): db1.dtype = float def test_overwrite_source(self): - db1 = DatasetBuilder(name='db1', data=[1, 2, 3], source='source') - msg = 'Cannot overwrite source.' + db1 = DatasetBuilder(name="db1", data=[1, 2, 3], source="source") + msg = "Cannot overwrite source." with self.assertRaisesWith(AttributeError, msg): - db1.source = 'new source' + db1.source = "new source" def test_overwrite_parent(self): - gb1 = GroupBuilder('gb1') - db1 = DatasetBuilder(name='db1', data=[1, 2, 3], parent=gb1) - msg = 'Cannot overwrite parent.' + gb1 = GroupBuilder("gb1") + db1 = DatasetBuilder(name="db1", data=[1, 2, 3], parent=gb1) + msg = "Cannot overwrite parent." with self.assertRaisesWith(AttributeError, msg): db1.parent = gb1 def test_repr(self): - gb1 = GroupBuilder('gb1') + gb1 = GroupBuilder("gb1") db1 = DatasetBuilder( - name='db1', + name="db1", data=[1, 2, 3], dtype=int, - attributes={'attr2': 10}, + attributes={"attr2": 10}, maxshape=10, chunks=True, parent=gb1, - source='source', + source="source", ) expected = "gb1/db1 DatasetBuilder {'attributes': {'attr2': 10}, 'data': [1, 2, 3]}" self.assertEqual(db1.__repr__(), expected) class TestLinkBuilder(TestCase): - def test_constructor(self): - gb = GroupBuilder('gb1') - db = DatasetBuilder('db1', [1, 2, 3]) - lb = LinkBuilder(db, 'link_name', gb, 'link_source') + gb = GroupBuilder("gb1") + db = DatasetBuilder("db1", [1, 2, 3]) + lb = LinkBuilder(db, "link_name", gb, "link_source") self.assertIs(lb.builder, db) - self.assertEqual(lb.name, 'link_name') + self.assertEqual(lb.name, "link_name") self.assertIs(lb.parent, gb) - self.assertEqual(lb.source, 'link_source') + self.assertEqual(lb.source, "link_source") def test_constructor_no_name(self): - db = DatasetBuilder('db1', [1, 2, 3]) + db = DatasetBuilder("db1", [1, 2, 3]) lb = LinkBuilder(db) self.assertIs(lb.builder, db) - self.assertEqual(lb.name, 'db1') + self.assertEqual(lb.name, "db1") class TestReferenceBuilder(TestCase): - def test_constructor(self): - db = DatasetBuilder('db1', [1, 2, 3]) + db = DatasetBuilder("db1", [1, 2, 3]) rb = ReferenceBuilder(db) self.assertIs(rb.builder, db) class TestRegionBuilder(TestCase): - def test_constructor(self): - db = DatasetBuilder('db1', [1, 2, 3]) + db = DatasetBuilder("db1", [1, 2, 3]) rb = RegionBuilder(slice(1, 3), db) self.assertEqual(rb.region, slice(1, 3)) self.assertIs(rb.builder, db) diff --git a/tests/unit/build_tests/test_classgenerator.py b/tests/unit/build_tests/test_classgenerator.py index 3bc0bf7f9..d32ec54c3 100644 --- a/tests/unit/build_tests/test_classgenerator.py +++ b/tests/unit/build_tests/test_classgenerator.py @@ -1,64 +1,81 @@ -import numpy as np import os import shutil import tempfile -from hdmf.build import TypeMap, CustomClassGenerator +import numpy as np + +from hdmf.build import CustomClassGenerator, TypeMap from hdmf.build.classgenerator import ClassGenerator, MCIClassGenerator -from hdmf.container import Container, Data, MultiContainerInterface, AbstractContainer -from hdmf.spec import GroupSpec, AttributeSpec, DatasetSpec, SpecCatalog, SpecNamespace, NamespaceCatalog, LinkSpec +from hdmf.container import AbstractContainer, Container, Data, MultiContainerInterface +from hdmf.spec import ( + AttributeSpec, + DatasetSpec, + GroupSpec, + LinkSpec, + NamespaceCatalog, + SpecCatalog, + SpecNamespace, +) from hdmf.testing import TestCase -from hdmf.utils import get_docval, docval +from hdmf.utils import docval, get_docval +from ..helpers.utils import ( + CORE_NAMESPACE, + 
create_load_namespace_yaml, + create_test_type_map, +) from .test_io_map import Bar -from tests.unit.helpers.utils import CORE_NAMESPACE, create_test_type_map, create_load_namespace_yaml class TestClassGenerator(TestCase): - def test_register_generator(self): """Test TypeMap.register_generator and ClassGenerator.register_generator.""" class MyClassGenerator(CustomClassGenerator): - @classmethod def apply_generator_to_field(cls, field_spec, bases, type_map): return True @classmethod - def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_inherited_fields, type_map, - spec): + def process_field_spec( + cls, + classdict, + docval_args, + parent_cls, + attr_name, + not_inherited_fields, + type_map, + spec, + ): # append attr_name to classdict['__custom_fields__'] list - classdict.setdefault('process_field_spec', list()).append(attr_name) + classdict.setdefault("process_field_spec", list()).append(attr_name) @classmethod def post_process(cls, classdict, bases, docval_args, spec): - classdict['post_process'] = True + classdict["post_process"] = True spec = GroupSpec( - doc='A test group specification with a data type', - data_type_def='Baz', - attributes=[ - AttributeSpec(name='attr1', doc='a string attribute', dtype='text') - ] + doc="A test group specification with a data type", + data_type_def="Baz", + attributes=[AttributeSpec(name="attr1", doc="a string attribute", dtype="text")], ) spec_catalog = SpecCatalog() - spec_catalog.register_spec(spec, 'test.yaml') + spec_catalog.register_spec(spec, "test.yaml") namespace = SpecNamespace( - doc='a test namespace', + doc="a test namespace", name=CORE_NAMESPACE, - schema=[{'source': 'test.yaml'}], - version='0.1.0', - catalog=spec_catalog + schema=[{"source": "test.yaml"}], + version="0.1.0", + catalog=spec_catalog, ) namespace_catalog = NamespaceCatalog() namespace_catalog.add_namespace(CORE_NAMESPACE, namespace) type_map = TypeMap(namespace_catalog) type_map.register_generator(MyClassGenerator) - cls = type_map.get_dt_container_cls('Baz', CORE_NAMESPACE) + cls = type_map.get_dt_container_cls("Baz", CORE_NAMESPACE) - self.assertEqual(cls.process_field_spec, ['attr1']) + self.assertEqual(cls.process_field_spec, ["attr1"]) self.assertTrue(cls.post_process) def test_bad_generator(self): @@ -69,111 +86,175 @@ class NotACustomClassGenerator: type_map = TypeMap() - msg = 'Generator <.*> must be a subclass of CustomClassGenerator.' + msg = "Generator <.*> must be a subclass of CustomClassGenerator." 
with self.assertRaisesRegex(ValueError, msg): type_map.register_generator(NotACustomClassGenerator) def test_no_generators(self): """Test that a ClassGenerator without registered generators does nothing.""" cg = ClassGenerator() - spec = GroupSpec(doc='A test group spec with a data type', data_type_def='Baz') - cls = cg.generate_class(data_type='Baz', spec=spec, parent_cls=Container, attr_names={}, type_map=TypeMap()) + spec = GroupSpec(doc="A test group spec with a data type", data_type_def="Baz") + cls = cg.generate_class( + data_type="Baz", + spec=spec, + parent_cls=Container, + attr_names={}, + type_map=TypeMap(), + ) self.assertEqual(cls.__mro__, (cls, Container, AbstractContainer, object)) - self.assertTrue(hasattr(cls, '__init__')) + self.assertTrue(hasattr(cls, "__init__")) class TestDynamicContainer(TestCase): - def setUp(self): self.bar_spec = GroupSpec( - doc='A test group specification with a data type', - data_type_def='Bar', + doc="A test group specification with a data type", + data_type_def="Bar", datasets=[ DatasetSpec( - doc='a dataset', - dtype='int', - name='data', - attributes=[AttributeSpec(name='attr2', doc='an integer attribute', dtype='int')] + doc="a dataset", + dtype="int", + name="data", + attributes=[ + AttributeSpec( + name="attr2", + doc="an integer attribute", + dtype="int", + ) + ], ) ], - attributes=[AttributeSpec(name='attr1', doc='a string attribute', dtype='text')]) + attributes=[AttributeSpec(name="attr1", doc="a string attribute", dtype="text")], + ) specs = [self.bar_spec] - containers = {'Bar': Bar} + containers = {"Bar": Bar} self.type_map = create_test_type_map(specs, containers) self.spec_catalog = self.type_map.namespace_catalog.get_namespace(CORE_NAMESPACE).catalog def test_dynamic_container_creation(self): - baz_spec = GroupSpec('A test extension with no Container class', - data_type_def='Baz', data_type_inc=self.bar_spec, - attributes=[AttributeSpec('attr3', 'a float attribute', 'float'), - AttributeSpec('attr4', 'another float attribute', 'float')]) - self.spec_catalog.register_spec(baz_spec, 'extension.yaml') - cls = self.type_map.get_dt_container_cls('Baz', CORE_NAMESPACE) - expected_args = {'name', 'data', 'attr1', 'attr2', 'attr3', 'attr4'} + baz_spec = GroupSpec( + "A test extension with no Container class", + data_type_def="Baz", + data_type_inc=self.bar_spec, + attributes=[ + AttributeSpec("attr3", "a float attribute", "float"), + AttributeSpec("attr4", "another float attribute", "float"), + ], + ) + self.spec_catalog.register_spec(baz_spec, "extension.yaml") + cls = self.type_map.get_dt_container_cls("Baz", CORE_NAMESPACE) + expected_args = {"name", "data", "attr1", "attr2", "attr3", "attr4"} received_args = set() for x in get_docval(cls.__init__): - if x['name'] != 'foo': - received_args.add(x['name']) - with self.subTest(name=x['name']): - self.assertNotIn('default', x) + if x["name"] != "foo": + received_args.add(x["name"]) + with self.subTest(name=x["name"]): + self.assertNotIn("default", x) self.assertSetEqual(expected_args, received_args) - self.assertEqual(cls.__name__, 'Baz') + self.assertEqual(cls.__name__, "Baz") self.assertTrue(issubclass(cls, Bar)) def test_dynamic_container_default_name(self): - baz_spec = GroupSpec('doc', default_name='bingo', data_type_def='Baz', - attributes=[AttributeSpec('attr4', 'another float attribute', 'float')]) - self.spec_catalog.register_spec(baz_spec, 'extension.yaml') - cls = self.type_map.get_dt_container_cls('Baz', CORE_NAMESPACE) - inst = cls(attr4=10.) 
- self.assertEqual(inst.name, 'bingo') + baz_spec = GroupSpec( + "doc", + default_name="bingo", + data_type_def="Baz", + attributes=[AttributeSpec("attr4", "another float attribute", "float")], + ) + self.spec_catalog.register_spec(baz_spec, "extension.yaml") + cls = self.type_map.get_dt_container_cls("Baz", CORE_NAMESPACE) + inst = cls(attr4=10.0) + self.assertEqual(inst.name, "bingo") def test_dynamic_container_creation_defaults(self): - baz_spec = GroupSpec('A test extension with no Container class', - data_type_def='Baz', data_type_inc=self.bar_spec, - attributes=[AttributeSpec('attr3', 'a float attribute', 'float'), - AttributeSpec('attr4', 'another float attribute', 'float')]) - self.spec_catalog.register_spec(baz_spec, 'extension.yaml') - cls = self.type_map.get_dt_container_cls('Baz', CORE_NAMESPACE) - expected_args = {'name', 'data', 'attr1', 'attr2', 'attr3', 'attr4', 'foo'} - received_args = set(map(lambda x: x['name'], get_docval(cls.__init__))) + baz_spec = GroupSpec( + "A test extension with no Container class", + data_type_def="Baz", + data_type_inc=self.bar_spec, + attributes=[ + AttributeSpec("attr3", "a float attribute", "float"), + AttributeSpec("attr4", "another float attribute", "float"), + ], + ) + self.spec_catalog.register_spec(baz_spec, "extension.yaml") + cls = self.type_map.get_dt_container_cls("Baz", CORE_NAMESPACE) + expected_args = { + "name", + "data", + "attr1", + "attr2", + "attr3", + "attr4", + "foo", + } + received_args = set(map(lambda x: x["name"], get_docval(cls.__init__))) self.assertSetEqual(expected_args, received_args) - self.assertEqual(cls.__name__, 'Baz') + self.assertEqual(cls.__name__, "Baz") self.assertTrue(issubclass(cls, Bar)) def test_dynamic_container_constructor(self): - baz_spec = GroupSpec('A test extension with no Container class', - data_type_def='Baz', data_type_inc=self.bar_spec, - attributes=[AttributeSpec('attr3', 'a float attribute', 'float'), - AttributeSpec('attr4', 'another float attribute', 'float')]) - self.spec_catalog.register_spec(baz_spec, 'extension.yaml') - cls = self.type_map.get_dt_container_cls('Baz', CORE_NAMESPACE) + baz_spec = GroupSpec( + "A test extension with no Container class", + data_type_def="Baz", + data_type_inc=self.bar_spec, + attributes=[ + AttributeSpec("attr3", "a float attribute", "float"), + AttributeSpec("attr4", "another float attribute", "float"), + ], + ) + self.spec_catalog.register_spec(baz_spec, "extension.yaml") + cls = self.type_map.get_dt_container_cls("Baz", CORE_NAMESPACE) # TODO: test that constructor works! 
- inst = cls(name='My Baz', data=[1, 2, 3, 4], attr1='string attribute', attr2=1000, attr3=98.6, attr4=1.0) - self.assertEqual(inst.name, 'My Baz') + inst = cls( + name="My Baz", + data=[1, 2, 3, 4], + attr1="string attribute", + attr2=1000, + attr3=98.6, + attr4=1.0, + ) + self.assertEqual(inst.name, "My Baz") self.assertEqual(inst.data, [1, 2, 3, 4]) - self.assertEqual(inst.attr1, 'string attribute') + self.assertEqual(inst.attr1, "string attribute") self.assertEqual(inst.attr2, 1000) self.assertEqual(inst.attr3, 98.6) self.assertEqual(inst.attr4, 1.0) def test_dynamic_container_constructor_name(self): # name is specified in spec and cannot be changed - baz_spec = GroupSpec('A test extension with no Container class', - data_type_def='Baz', data_type_inc=self.bar_spec, - name='A fixed name', - attributes=[AttributeSpec('attr3', 'a float attribute', 'float'), - AttributeSpec('attr4', 'another float attribute', 'float')]) - self.spec_catalog.register_spec(baz_spec, 'extension.yaml') - cls = self.type_map.get_dt_container_cls('Baz', CORE_NAMESPACE) + baz_spec = GroupSpec( + "A test extension with no Container class", + data_type_def="Baz", + data_type_inc=self.bar_spec, + name="A fixed name", + attributes=[ + AttributeSpec("attr3", "a float attribute", "float"), + AttributeSpec("attr4", "another float attribute", "float"), + ], + ) + self.spec_catalog.register_spec(baz_spec, "extension.yaml") + cls = self.type_map.get_dt_container_cls("Baz", CORE_NAMESPACE) with self.assertRaises(TypeError): - inst = cls(name='My Baz', data=[1, 2, 3, 4], attr1='string attribute', attr2=1000, attr3=98.6, attr4=1.0) + inst = cls( + name="My Baz", + data=[1, 2, 3, 4], + attr1="string attribute", + attr2=1000, + attr3=98.6, + attr4=1.0, + ) - inst = cls(data=[1, 2, 3, 4], attr1='string attribute', attr2=1000, attr3=98.6, attr4=1.0) - self.assertEqual(inst.name, 'A fixed name') + inst = cls( + data=[1, 2, 3, 4], + attr1="string attribute", + attr2=1000, + attr3=98.6, + attr4=1.0, + ) + self.assertEqual(inst.name, "A fixed name") self.assertEqual(inst.data, [1, 2, 3, 4]) - self.assertEqual(inst.attr1, 'string attribute') + self.assertEqual(inst.attr1, "string attribute") self.assertEqual(inst.attr2, 1000) self.assertEqual(inst.attr3, 98.6) self.assertEqual(inst.attr4, 1.0) @@ -181,97 +262,192 @@ def test_dynamic_container_constructor_name(self): def test_dynamic_container_constructor_name_default_name(self): # if both name and default_name are specified, name should be used with self.assertWarns(Warning): - baz_spec = GroupSpec('A test extension with no Container class', - data_type_def='Baz', data_type_inc=self.bar_spec, - name='A fixed name', - default_name='A default name', - attributes=[AttributeSpec('attr3', 'a float attribute', 'float'), - AttributeSpec('attr4', 'another float attribute', 'float')]) - self.spec_catalog.register_spec(baz_spec, 'extension.yaml') - cls = self.type_map.get_dt_container_cls('Baz', CORE_NAMESPACE) - - inst = cls(data=[1, 2, 3, 4], attr1='string attribute', attr2=1000, attr3=98.6, attr4=1.0) - self.assertEqual(inst.name, 'A fixed name') + baz_spec = GroupSpec( + "A test extension with no Container class", + data_type_def="Baz", + data_type_inc=self.bar_spec, + name="A fixed name", + default_name="A default name", + attributes=[ + AttributeSpec("attr3", "a float attribute", "float"), + AttributeSpec("attr4", "another float attribute", "float"), + ], + ) + self.spec_catalog.register_spec(baz_spec, "extension.yaml") + cls = self.type_map.get_dt_container_cls("Baz", CORE_NAMESPACE) + + 
inst = cls( + data=[1, 2, 3, 4], + attr1="string attribute", + attr2=1000, + attr3=98.6, + attr4=1.0, + ) + self.assertEqual(inst.name, "A fixed name") def test_dynamic_container_composition(self): - baz_spec2 = GroupSpec('A composition inside', data_type_def='Baz2', - data_type_inc=self.bar_spec, - attributes=[ - AttributeSpec('attr3', 'a float attribute', 'float'), - AttributeSpec('attr4', 'another float attribute', 'float')]) - - baz_spec1 = GroupSpec('A composition test outside', data_type_def='Baz1', data_type_inc=self.bar_spec, - attributes=[AttributeSpec('attr3', 'a float attribute', 'float'), - AttributeSpec('attr4', 'another float attribute', 'float')], - groups=[GroupSpec('A composition inside', data_type_inc='Baz2')]) - self.spec_catalog.register_spec(baz_spec1, 'extension.yaml') - self.spec_catalog.register_spec(baz_spec2, 'extension.yaml') - Baz2 = self.type_map.get_dt_container_cls('Baz2', CORE_NAMESPACE) - Baz1 = self.type_map.get_dt_container_cls('Baz1', CORE_NAMESPACE) - Baz1(name='My Baz', data=[1, 2, 3, 4], attr1='string attribute', attr2=1000, attr3=98.6, attr4=1.0, - baz2=Baz2(name='My Baz', data=[1, 2, 3, 4], attr1='string attribute', attr2=1000, attr3=98.6, attr4=1.0)) - - Bar = self.type_map.get_dt_container_cls('Bar', CORE_NAMESPACE) - bar = Bar(name='My Bar', data=[1, 2, 3, 4], attr1='string attribute', attr2=1000) + baz_spec2 = GroupSpec( + "A composition inside", + data_type_def="Baz2", + data_type_inc=self.bar_spec, + attributes=[ + AttributeSpec("attr3", "a float attribute", "float"), + AttributeSpec("attr4", "another float attribute", "float"), + ], + ) + + baz_spec1 = GroupSpec( + "A composition test outside", + data_type_def="Baz1", + data_type_inc=self.bar_spec, + attributes=[ + AttributeSpec("attr3", "a float attribute", "float"), + AttributeSpec("attr4", "another float attribute", "float"), + ], + groups=[GroupSpec("A composition inside", data_type_inc="Baz2")], + ) + self.spec_catalog.register_spec(baz_spec1, "extension.yaml") + self.spec_catalog.register_spec(baz_spec2, "extension.yaml") + Baz2 = self.type_map.get_dt_container_cls("Baz2", CORE_NAMESPACE) + Baz1 = self.type_map.get_dt_container_cls("Baz1", CORE_NAMESPACE) + Baz1( + name="My Baz", + data=[1, 2, 3, 4], + attr1="string attribute", + attr2=1000, + attr3=98.6, + attr4=1.0, + baz2=Baz2( + name="My Baz", + data=[1, 2, 3, 4], + attr1="string attribute", + attr2=1000, + attr3=98.6, + attr4=1.0, + ), + ) + + Bar = self.type_map.get_dt_container_cls("Bar", CORE_NAMESPACE) + bar = Bar( + name="My Bar", + data=[1, 2, 3, 4], + attr1="string attribute", + attr2=1000, + ) with self.assertRaises(TypeError): - Baz1(name='My Baz', data=[1, 2, 3, 4], attr1='string attribute', attr2=1000, attr3=98.6, attr4=1.0, - baz2=bar) + Baz1( + name="My Baz", + data=[1, 2, 3, 4], + attr1="string attribute", + attr2=1000, + attr3=98.6, + attr4=1.0, + baz2=bar, + ) def test_dynamic_container_composition_reverse_order(self): - baz_spec2 = GroupSpec('A composition inside', data_type_def='Baz2', - data_type_inc=self.bar_spec, - attributes=[ - AttributeSpec('attr3', 'a float attribute', 'float'), - AttributeSpec('attr4', 'another float attribute', 'float')]) - - baz_spec1 = GroupSpec('A composition test outside', data_type_def='Baz1', data_type_inc=self.bar_spec, - attributes=[AttributeSpec('attr3', 'a float attribute', 'float'), - AttributeSpec('attr4', 'another float attribute', 'float')], - groups=[GroupSpec('A composition inside', data_type_inc='Baz2')]) - self.spec_catalog.register_spec(baz_spec1, 'extension.yaml') - 
self.spec_catalog.register_spec(baz_spec2, 'extension.yaml') - Baz1 = self.type_map.get_dt_container_cls('Baz1', CORE_NAMESPACE) - Baz2 = self.type_map.get_dt_container_cls('Baz2', CORE_NAMESPACE) - Baz1(name='My Baz', data=[1, 2, 3, 4], attr1='string attribute', attr2=1000, attr3=98.6, attr4=1.0, - baz2=Baz2(name='My Baz', data=[1, 2, 3, 4], attr1='string attribute', attr2=1000, attr3=98.6, attr4=1.0)) - - Bar = self.type_map.get_dt_container_cls('Bar', CORE_NAMESPACE) - bar = Bar(name='My Bar', data=[1, 2, 3, 4], attr1='string attribute', attr2=1000) + baz_spec2 = GroupSpec( + "A composition inside", + data_type_def="Baz2", + data_type_inc=self.bar_spec, + attributes=[ + AttributeSpec("attr3", "a float attribute", "float"), + AttributeSpec("attr4", "another float attribute", "float"), + ], + ) + + baz_spec1 = GroupSpec( + "A composition test outside", + data_type_def="Baz1", + data_type_inc=self.bar_spec, + attributes=[ + AttributeSpec("attr3", "a float attribute", "float"), + AttributeSpec("attr4", "another float attribute", "float"), + ], + groups=[GroupSpec("A composition inside", data_type_inc="Baz2")], + ) + self.spec_catalog.register_spec(baz_spec1, "extension.yaml") + self.spec_catalog.register_spec(baz_spec2, "extension.yaml") + Baz1 = self.type_map.get_dt_container_cls("Baz1", CORE_NAMESPACE) + Baz2 = self.type_map.get_dt_container_cls("Baz2", CORE_NAMESPACE) + Baz1( + name="My Baz", + data=[1, 2, 3, 4], + attr1="string attribute", + attr2=1000, + attr3=98.6, + attr4=1.0, + baz2=Baz2( + name="My Baz", + data=[1, 2, 3, 4], + attr1="string attribute", + attr2=1000, + attr3=98.6, + attr4=1.0, + ), + ) + + Bar = self.type_map.get_dt_container_cls("Bar", CORE_NAMESPACE) + bar = Bar( + name="My Bar", + data=[1, 2, 3, 4], + attr1="string attribute", + attr2=1000, + ) with self.assertRaises(TypeError): - Baz1(name='My Baz', data=[1, 2, 3, 4], attr1='string attribute', - attr2=1000, attr3=98.6, attr4=1.0, baz2=bar) + Baz1( + name="My Baz", + data=[1, 2, 3, 4], + attr1="string attribute", + attr2=1000, + attr3=98.6, + attr4=1.0, + baz2=bar, + ) def test_dynamic_container_composition_missing_type(self): - baz_spec1 = GroupSpec('A composition test outside', data_type_def='Baz1', data_type_inc=self.bar_spec, - attributes=[AttributeSpec('attr3', 'a float attribute', 'float'), - AttributeSpec('attr4', 'another float attribute', 'float')], - groups=[GroupSpec('A composition inside', data_type_inc='Baz2')]) - self.spec_catalog.register_spec(baz_spec1, 'extension.yaml') + baz_spec1 = GroupSpec( + "A composition test outside", + data_type_def="Baz1", + data_type_inc=self.bar_spec, + attributes=[ + AttributeSpec("attr3", "a float attribute", "float"), + AttributeSpec("attr4", "another float attribute", "float"), + ], + groups=[GroupSpec("A composition inside", data_type_inc="Baz2")], + ) + self.spec_catalog.register_spec(baz_spec1, "extension.yaml") msg = "No specification for 'Baz2' in namespace 'test_core'" with self.assertRaisesWith(ValueError, msg): - self.type_map.get_dt_container_cls('Baz1', CORE_NAMESPACE) + self.type_map.get_dt_container_cls("Baz1", CORE_NAMESPACE) def test_dynamic_container_fixed_name(self): """Test that dynamic class generation for an extended type with a fixed name works.""" - baz_spec = GroupSpec('A test extension with no Container class', - data_type_def='Baz', data_type_inc=self.bar_spec, name='Baz') - self.spec_catalog.register_spec(baz_spec, 'extension.yaml') - Baz = self.type_map.get_dt_container_cls('Baz', CORE_NAMESPACE) - obj = Baz(data=[1, 2, 3, 4], 
attr1='string attribute', attr2=1000) - self.assertEqual(obj.name, 'Baz') + baz_spec = GroupSpec( + "A test extension with no Container class", + data_type_def="Baz", + data_type_inc=self.bar_spec, + name="Baz", + ) + self.spec_catalog.register_spec(baz_spec, "extension.yaml") + Baz = self.type_map.get_dt_container_cls("Baz", CORE_NAMESPACE) + obj = Baz(data=[1, 2, 3, 4], attr1="string attribute", attr2=1000) + self.assertEqual(obj.name, "Baz") def test_dynamic_container_super_init_fixed_value(self): """Test that dynamic class generation when the superclass init does not include all fields works""" class FixedAttrBar(Bar): - @docval({'name': 'name', 'type': str, 'doc': 'the name of this Bar'}, - {'name': 'data', 'type': ('data', 'array_data'), 'doc': 'some data'}, - {'name': 'attr2', 'type': int, 'doc': 'another attribute'}, - {'name': 'attr3', 'type': float, 'doc': 'a third attribute', 'default': 3.14}, - {'name': 'foo', 'type': 'Foo', 'doc': 'a group', 'default': None}) + @docval( + {"name": "name", "type": str, "doc": "the name of this Bar"}, + {"name": "data", "type": ("data", "array_data"), "doc": "some data"}, + {"name": "attr2", "type": int, "doc": "another attribute"}, + {"name": "attr3", "type": float, "doc": "a third attribute", "default": 3.14}, + {"name": "foo", "type": "Foo", "doc": "a group", "default": None}, + ) def __init__(self, **kwargs): kwargs["attr1"] = "fixed_attr1" super().__init__(**kwargs) @@ -279,19 +455,24 @@ def __init__(self, **kwargs): # overwrite the "Bar" to Bar class mapping from setUp() self.type_map.register_container_type(CORE_NAMESPACE, "Bar", FixedAttrBar) - baz_spec = GroupSpec('A test extension with no Container class', - data_type_def='Baz', data_type_inc=self.bar_spec, - attributes=[AttributeSpec('attr3', 'a float attribute', 'float'), - AttributeSpec('attr4', 'another float attribute', 'float')]) - self.spec_catalog.register_spec(baz_spec, 'extension.yaml') - cls = self.type_map.get_dt_container_cls('Baz', CORE_NAMESPACE) - expected_args = {'name', 'data', 'attr2', 'attr3', 'attr4'} + baz_spec = GroupSpec( + "A test extension with no Container class", + data_type_def="Baz", + data_type_inc=self.bar_spec, + attributes=[ + AttributeSpec("attr3", "a float attribute", "float"), + AttributeSpec("attr4", "another float attribute", "float"), + ], + ) + self.spec_catalog.register_spec(baz_spec, "extension.yaml") + cls = self.type_map.get_dt_container_cls("Baz", CORE_NAMESPACE) + expected_args = {"name", "data", "attr2", "attr3", "attr4"} received_args = set() for x in get_docval(cls.__init__): - if x['name'] != 'foo': - received_args.add(x['name']) - with self.subTest(name=x['name']): - self.assertNotIn('default', x) + if x["name"] != "foo": + received_args.add(x["name"]) + with self.subTest(name=x["name"]): + self.assertNotIn("default", x) self.assertSetEqual(expected_args, received_args) self.assertTrue(issubclass(cls, FixedAttrBar)) inst = cls(name="My Baz", data=[1, 2, 3, 4], attr2=1000, attr3=98.6, attr4=1.0) @@ -299,140 +480,144 @@ def __init__(self, **kwargs): def test_multi_container_spec(self): multi_spec = GroupSpec( - doc='A test extension that contains a multi', - data_type_def='Multi', - groups=[ - GroupSpec(data_type_inc=self.bar_spec, doc='test multi', quantity='*') - ], - attributes=[ - AttributeSpec(name='attr3', doc='a float attribute', dtype='float') - ] + doc="A test extension that contains a multi", + data_type_def="Multi", + groups=[GroupSpec(data_type_inc=self.bar_spec, doc="test multi", quantity="*")], + 
attributes=[AttributeSpec(name="attr3", doc="a float attribute", dtype="float")], ) - self.spec_catalog.register_spec(multi_spec, 'extension.yaml') - Bar = self.type_map.get_dt_container_cls('Bar', CORE_NAMESPACE) - Multi = self.type_map.get_dt_container_cls('Multi', CORE_NAMESPACE) + self.spec_catalog.register_spec(multi_spec, "extension.yaml") + Bar = self.type_map.get_dt_container_cls("Bar", CORE_NAMESPACE) + Multi = self.type_map.get_dt_container_cls("Multi", CORE_NAMESPACE) assert issubclass(Multi, MultiContainerInterface) assert Multi.__clsconf__ == [ dict( - attr='bars', + attr="bars", type=Bar, - add='add_bars', - get='get_bars', - create='create_bars' + add="add_bars", + get="get_bars", + create="create_bars", ) ] multi = Multi( - name='my_multi', - bars=[Bar(name='my_bar', data=list(range(10)), attr1='value1', attr2=10)], - attr3=5. + name="my_multi", + bars=[ + Bar( + name="my_bar", + data=list(range(10)), + attr1="value1", + attr2=10, + ) + ], + attr3=5.0, ) - assert multi.bars['my_bar'] == Bar(name='my_bar', data=list(range(10)), attr1='value1', attr2=10) - assert multi.attr3 == 5. + assert multi.bars["my_bar"] == Bar(name="my_bar", data=list(range(10)), attr1="value1", attr2=10) + assert multi.attr3 == 5.0 def test_multi_container_spec_with_inc(self): multi_spec = GroupSpec( - doc='A test extension that contains a multi', - data_type_def='Multi', + doc="A test extension that contains a multi", + data_type_def="Multi", data_type_inc=self.bar_spec, - groups=[ - GroupSpec(data_type_inc=self.bar_spec, doc='test multi', quantity='*') - ], - attributes=[ - AttributeSpec(name='attr3', doc='a float attribute', dtype='float') - ] + groups=[GroupSpec(data_type_inc=self.bar_spec, doc="test multi", quantity="*")], + attributes=[AttributeSpec(name="attr3", doc="a float attribute", dtype="float")], ) - self.spec_catalog.register_spec(multi_spec, 'extension.yaml') - Bar = self.type_map.get_dt_container_cls('Bar', CORE_NAMESPACE) - Multi = self.type_map.get_dt_container_cls('Multi', CORE_NAMESPACE) + self.spec_catalog.register_spec(multi_spec, "extension.yaml") + Bar = self.type_map.get_dt_container_cls("Bar", CORE_NAMESPACE) + Multi = self.type_map.get_dt_container_cls("Multi", CORE_NAMESPACE) assert issubclass(Multi, MultiContainerInterface) assert issubclass(Multi, Bar) - assert Multi.__mro__ == (Multi, Bar, MultiContainerInterface, Container, AbstractContainer, object) + assert Multi.__mro__ == ( + Multi, + Bar, + MultiContainerInterface, + Container, + AbstractContainer, + object, + ) assert Multi.__clsconf__ == [ dict( - attr='bars', + attr="bars", type=Bar, - add='add_bars', - get='get_bars', - create='create_bars' + add="add_bars", + get="get_bars", + create="create_bars", ) ] multi = Multi( - name='my_multi', - bars=[Bar(name='my_bar', data=list(range(10)), attr1='value1', attr2=10)], + name="my_multi", + bars=[ + Bar( + name="my_bar", + data=list(range(10)), + attr1="value1", + attr2=10, + ) + ], data=list(range(10)), # from data_type_inc: Bar - attr1='value1', # from data_type_inc: Bar + attr1="value1", # from data_type_inc: Bar attr2=10, # from data_type_inc: Bar - attr3=5. + attr3=5.0, ) assert multi.data == list(range(10)) - assert multi.attr1 == 'value1' + assert multi.attr1 == "value1" assert multi.attr2 == 10 - assert multi.bars['my_bar'] == Bar(name='my_bar', data=list(range(10)), attr1='value1', attr2=10) - assert multi.attr3 == 5. 
+ assert multi.bars["my_bar"] == Bar(name="my_bar", data=list(range(10)), attr1="value1", attr2=10) + assert multi.attr3 == 5.0 def test_multi_container_spec_zero_or_more(self): multi_spec = GroupSpec( - doc='A test extension that contains a multi', - data_type_def='Multi', - groups=[ - GroupSpec(data_type_inc=self.bar_spec, doc='test multi', quantity='*') - ], - attributes=[ - AttributeSpec(name='attr3', doc='a float attribute', dtype='float') - ] - ) - self.spec_catalog.register_spec(multi_spec, 'extension.yaml') - Multi = self.type_map.get_dt_container_cls('Multi', CORE_NAMESPACE) - multi = Multi( - name='my_multi', - attr3=5. + doc="A test extension that contains a multi", + data_type_def="Multi", + groups=[GroupSpec(data_type_inc=self.bar_spec, doc="test multi", quantity="*")], + attributes=[AttributeSpec(name="attr3", doc="a float attribute", dtype="float")], ) + self.spec_catalog.register_spec(multi_spec, "extension.yaml") + Multi = self.type_map.get_dt_container_cls("Multi", CORE_NAMESPACE) + multi = Multi(name="my_multi", attr3=5.0) assert len(multi.bars) == 0 def test_multi_container_spec_one_or_more_missing(self): multi_spec = GroupSpec( - doc='A test extension that contains a multi', - data_type_def='Multi', - groups=[ - GroupSpec(data_type_inc=self.bar_spec, doc='test multi', quantity='+') - ], - attributes=[ - AttributeSpec(name='attr3', doc='a float attribute', dtype='float') - ] + doc="A test extension that contains a multi", + data_type_def="Multi", + groups=[GroupSpec(data_type_inc=self.bar_spec, doc="test multi", quantity="+")], + attributes=[AttributeSpec(name="attr3", doc="a float attribute", dtype="float")], ) - self.spec_catalog.register_spec(multi_spec, 'extension.yaml') - Multi = self.type_map.get_dt_container_cls('Multi', CORE_NAMESPACE) - with self.assertRaisesWith(TypeError, "MCIClassGenerator.set_init..__init__: missing argument 'bars'"): - Multi( - name='my_multi', - attr3=5. - ) + self.spec_catalog.register_spec(multi_spec, "extension.yaml") + Multi = self.type_map.get_dt_container_cls("Multi", CORE_NAMESPACE) + with self.assertRaisesWith( + TypeError, + "MCIClassGenerator.set_init..__init__: missing argument 'bars'", + ): + Multi(name="my_multi", attr3=5.0) def test_multi_container_spec_one_or_more_ok(self): multi_spec = GroupSpec( - doc='A test extension that contains a multi', - data_type_def='Multi', - groups=[ - GroupSpec(data_type_inc=self.bar_spec, doc='test multi', quantity='+') - ], - attributes=[ - AttributeSpec(name='attr3', doc='a float attribute', dtype='float') - ] + doc="A test extension that contains a multi", + data_type_def="Multi", + groups=[GroupSpec(data_type_inc=self.bar_spec, doc="test multi", quantity="+")], + attributes=[AttributeSpec(name="attr3", doc="a float attribute", dtype="float")], ) - self.spec_catalog.register_spec(multi_spec, 'extension.yaml') - Multi = self.type_map.get_dt_container_cls('Multi', CORE_NAMESPACE) + self.spec_catalog.register_spec(multi_spec, "extension.yaml") + Multi = self.type_map.get_dt_container_cls("Multi", CORE_NAMESPACE) multi = Multi( - name='my_multi', - bars=[Bar(name='my_bar', data=list(range(10)), attr1='value1', attr2=10)], - attr3=5. 
+ name="my_multi", + bars=[ + Bar( + name="my_bar", + data=list(range(10)), + attr1="value1", + attr2=10, + ) + ], + attr3=5.0, ) assert len(multi.bars) == 1 class TestGetClassSeparateNamespace(TestCase): - def setUp(self): self.test_dir = tempfile.mkdtemp() if os.path.exists(self.test_dir): # start clean @@ -440,15 +625,13 @@ def setUp(self): os.mkdir(self.test_dir) self.bar_spec = GroupSpec( - doc='A test group specification with a data type', - data_type_def='Bar', - datasets=[ - DatasetSpec(name='data', doc='a dataset', dtype='int') - ], + doc="A test group specification with a data type", + data_type_def="Bar", + datasets=[DatasetSpec(name="data", doc="a dataset", dtype="int")], attributes=[ - AttributeSpec(name='attr1', doc='a string attribute', dtype='text'), - AttributeSpec(name='attr2', doc='an integer attribute', dtype='int') - ] + AttributeSpec(name="attr1", doc="a string attribute", dtype="text"), + AttributeSpec(name="attr2", doc="an integer attribute", dtype="int"), + ], ) self.type_map = TypeMap() create_load_namespace_yaml( @@ -456,7 +639,7 @@ def setUp(self): specs=[self.bar_spec], output_dir=self.test_dir, incl_types=dict(), - type_map=self.type_map + type_map=self.type_map, ) def tearDown(self): @@ -464,78 +647,76 @@ def tearDown(self): def test_get_class_separate_ns(self): """Test that get_class correctly sets the name and type hierarchy across namespaces.""" - self.type_map.register_container_type(CORE_NAMESPACE, 'Bar', Bar) + self.type_map.register_container_type(CORE_NAMESPACE, "Bar", Bar) baz_spec = GroupSpec( - doc='A test extension', - data_type_def='Baz', - data_type_inc='Bar', + doc="A test extension", + data_type_def="Baz", + data_type_inc="Bar", ) create_load_namespace_yaml( - namespace_name='ndx-test', + namespace_name="ndx-test", specs=[baz_spec], output_dir=self.test_dir, - incl_types={CORE_NAMESPACE: ['Bar']}, - type_map=self.type_map + incl_types={CORE_NAMESPACE: ["Bar"]}, + type_map=self.type_map, ) - cls = self.type_map.get_dt_container_cls('Baz', 'ndx-test') - self.assertEqual(cls.__name__, 'Baz') + cls = self.type_map.get_dt_container_cls("Baz", "ndx-test") + self.assertEqual(cls.__name__, "Baz") self.assertTrue(issubclass(cls, Bar)) def _build_separate_namespaces(self): # create an empty extension to test ClassGenerator._get_container_type resolution # the Bar class has not been mapped yet to the bar spec - qux_spec = DatasetSpec( - doc='A test extension', - data_type_def='Qux' - ) - spam_spec = DatasetSpec( - doc='A test extension', - data_type_def='Spam' - ) + qux_spec = DatasetSpec(doc="A test extension", data_type_def="Qux") + spam_spec = DatasetSpec(doc="A test extension", data_type_def="Spam") create_load_namespace_yaml( - namespace_name='ndx-qux', + namespace_name="ndx-qux", specs=[qux_spec, spam_spec], output_dir=self.test_dir, incl_types={}, - type_map=self.type_map + type_map=self.type_map, ) # resolve Spam first so that ndx-qux is resolved first - self.type_map.get_dt_container_cls('Spam', 'ndx-qux') + self.type_map.get_dt_container_cls("Spam", "ndx-qux") baz_spec = GroupSpec( - doc='A test extension', - data_type_def='Baz', - data_type_inc='Bar', + doc="A test extension", + data_type_def="Baz", + data_type_inc="Bar", groups=[ - GroupSpec(data_type_inc='Qux', doc='a qux', quantity='?'), - GroupSpec(data_type_inc='Bar', doc='a bar', quantity='?') - ] + GroupSpec(data_type_inc="Qux", doc="a qux", quantity="?"), + GroupSpec(data_type_inc="Bar", doc="a bar", quantity="?"), + ], ) create_load_namespace_yaml( - namespace_name='ndx-test', + 
namespace_name="ndx-test", specs=[baz_spec], output_dir=self.test_dir, - incl_types={ - CORE_NAMESPACE: ['Bar'], - 'ndx-qux': ['Qux'] - }, - type_map=self.type_map + incl_types={CORE_NAMESPACE: ["Bar"], "ndx-qux": ["Qux"]}, + type_map=self.type_map, ) def _check_classes(self, baz_cls, bar_cls, bar_cls2, qux_cls, qux_cls2): - self.assertEqual(qux_cls.__name__, 'Qux') - self.assertEqual(baz_cls.__name__, 'Baz') - self.assertEqual(bar_cls.__name__, 'Bar') + self.assertEqual(qux_cls.__name__, "Qux") + self.assertEqual(baz_cls.__name__, "Baz") + self.assertEqual(bar_cls.__name__, "Bar") self.assertIs(bar_cls, bar_cls2) # same class, two different namespaces self.assertIs(qux_cls, qux_cls2) self.assertTrue(issubclass(qux_cls, Data)) self.assertTrue(issubclass(baz_cls, bar_cls)) self.assertTrue(issubclass(bar_cls, Container)) - qux_inst = qux_cls(name='qux_name', data=[1]) - bar_inst = bar_cls(name='bar_name', data=100, attr1='a string', attr2=10) - baz_inst = baz_cls(name='baz_name', qux=qux_inst, bar=bar_inst, data=100, attr1='a string', attr2=10) + qux_inst = qux_cls(name="qux_name", data=[1]) + bar_inst = bar_cls(name="bar_name", data=100, attr1="a string", attr2=10) + baz_inst = baz_cls( + name="baz_name", + qux=qux_inst, + bar=bar_inst, + data=100, + attr1="a string", + attr2=10, + ) self.assertIs(baz_inst.qux, qux_inst) def test_get_class_include_from_separate_ns_1(self): @@ -547,11 +728,11 @@ def test_get_class_include_from_separate_ns_1(self): """ self._build_separate_namespaces() - baz_cls = self.type_map.get_dt_container_cls('Baz', 'ndx-test') # Qux and Bar are not yet resolved - bar_cls = self.type_map.get_dt_container_cls('Bar', 'ndx-test') - bar_cls2 = self.type_map.get_dt_container_cls('Bar', CORE_NAMESPACE) - qux_cls = self.type_map.get_dt_container_cls('Qux', 'ndx-test') - qux_cls2 = self.type_map.get_dt_container_cls('Qux', 'ndx-qux') + baz_cls = self.type_map.get_dt_container_cls("Baz", "ndx-test") # Qux and Bar are not yet resolved + bar_cls = self.type_map.get_dt_container_cls("Bar", "ndx-test") + bar_cls2 = self.type_map.get_dt_container_cls("Bar", CORE_NAMESPACE) + qux_cls = self.type_map.get_dt_container_cls("Qux", "ndx-test") + qux_cls2 = self.type_map.get_dt_container_cls("Qux", "ndx-qux") self._check_classes(baz_cls, bar_cls, bar_cls2, qux_cls, qux_cls2) @@ -564,11 +745,11 @@ def test_get_class_include_from_separate_ns_2(self): """ self._build_separate_namespaces() - baz_cls = self.type_map.get_dt_container_cls('Baz', 'ndx-test') # Qux and Bar are not yet resolved - bar_cls2 = self.type_map.get_dt_container_cls('Bar', CORE_NAMESPACE) - bar_cls = self.type_map.get_dt_container_cls('Bar', 'ndx-test') - qux_cls = self.type_map.get_dt_container_cls('Qux', 'ndx-test') - qux_cls2 = self.type_map.get_dt_container_cls('Qux', 'ndx-qux') + baz_cls = self.type_map.get_dt_container_cls("Baz", "ndx-test") # Qux and Bar are not yet resolved + bar_cls2 = self.type_map.get_dt_container_cls("Bar", CORE_NAMESPACE) + bar_cls = self.type_map.get_dt_container_cls("Bar", "ndx-test") + qux_cls = self.type_map.get_dt_container_cls("Qux", "ndx-test") + qux_cls2 = self.type_map.get_dt_container_cls("Qux", "ndx-qux") self._check_classes(baz_cls, bar_cls, bar_cls2, qux_cls, qux_cls2) @@ -581,11 +762,11 @@ def test_get_class_include_from_separate_ns_3(self): """ self._build_separate_namespaces() - baz_cls = self.type_map.get_dt_container_cls('Baz', 'ndx-test') # Qux and Bar are not yet resolved - bar_cls = self.type_map.get_dt_container_cls('Bar', 'ndx-test') - bar_cls2 = 
self.type_map.get_dt_container_cls('Bar', CORE_NAMESPACE) - qux_cls2 = self.type_map.get_dt_container_cls('Qux', 'ndx-qux') - qux_cls = self.type_map.get_dt_container_cls('Qux', 'ndx-test') + baz_cls = self.type_map.get_dt_container_cls("Baz", "ndx-test") # Qux and Bar are not yet resolved + bar_cls = self.type_map.get_dt_container_cls("Bar", "ndx-test") + bar_cls2 = self.type_map.get_dt_container_cls("Bar", CORE_NAMESPACE) + qux_cls2 = self.type_map.get_dt_container_cls("Qux", "ndx-qux") + qux_cls = self.type_map.get_dt_container_cls("Qux", "ndx-test") self._check_classes(baz_cls, bar_cls, bar_cls2, qux_cls, qux_cls2) @@ -598,11 +779,11 @@ def test_get_class_include_from_separate_ns_4(self): """ self._build_separate_namespaces() - baz_cls = self.type_map.get_dt_container_cls('Baz', 'ndx-test') # Qux and Bar are not yet resolved - bar_cls2 = self.type_map.get_dt_container_cls('Bar', CORE_NAMESPACE) - bar_cls = self.type_map.get_dt_container_cls('Bar', 'ndx-test') - qux_cls2 = self.type_map.get_dt_container_cls('Qux', 'ndx-qux') - qux_cls = self.type_map.get_dt_container_cls('Qux', 'ndx-test') + baz_cls = self.type_map.get_dt_container_cls("Baz", "ndx-test") # Qux and Bar are not yet resolved + bar_cls2 = self.type_map.get_dt_container_cls("Bar", CORE_NAMESPACE) + bar_cls = self.type_map.get_dt_container_cls("Bar", "ndx-test") + qux_cls2 = self.type_map.get_dt_container_cls("Qux", "ndx-qux") + qux_cls = self.type_map.get_dt_container_cls("Qux", "ndx-test") self._check_classes(baz_cls, bar_cls, bar_cls2, qux_cls, qux_cls2) @@ -612,38 +793,42 @@ class EmptyBar(Container): class TestBaseProcessFieldSpec(TestCase): - def setUp(self): self.bar_spec = GroupSpec( - doc='A test group specification with a data type', - data_type_def='EmptyBar' + doc="A test group specification with a data type", + data_type_def="EmptyBar", ) self.spec_catalog = SpecCatalog() - self.spec_catalog.register_spec(self.bar_spec, 'test.yaml') - self.namespace = SpecNamespace('a test namespace', CORE_NAMESPACE, - [{'source': 'test.yaml'}], - version='0.1.0', - catalog=self.spec_catalog) + self.spec_catalog.register_spec(self.bar_spec, "test.yaml") + self.namespace = SpecNamespace( + "a test namespace", + CORE_NAMESPACE, + [{"source": "test.yaml"}], + version="0.1.0", + catalog=self.spec_catalog, + ) self.namespace_catalog = NamespaceCatalog() self.namespace_catalog.add_namespace(CORE_NAMESPACE, self.namespace) self.type_map = TypeMap(self.namespace_catalog) - self.type_map.register_container_type(CORE_NAMESPACE, 'EmptyBar', EmptyBar) + self.type_map.register_container_type(CORE_NAMESPACE, "EmptyBar", EmptyBar) def test_update_docval(self): """Test update_docval_args for a variety of data types and mapping configurations.""" spec = GroupSpec( doc="A test group specification with a data type", data_type_def="Baz", - groups=[ - GroupSpec(doc="a group", data_type_inc="EmptyBar", quantity="?") - ], + groups=[GroupSpec(doc="a group", data_type_inc="EmptyBar", quantity="?")], datasets=[ DatasetSpec( doc="a dataset", dtype="int", name="data", attributes=[ - AttributeSpec(name="attr2", doc="an integer attribute", dtype="int") + AttributeSpec( + name="attr2", + doc="an integer attribute", + dtype="int", + ) ], ) ], @@ -655,25 +840,55 @@ def test_update_docval(self): ) expected = [ - {'name': 'data', 'type': (int, np.int32, np.int64), 'doc': 'a dataset'}, - {'name': 'attr1', 'type': str, 'doc': 'a string attribute'}, - {'name': 'attr2', 'type': (int, np.int32, np.int64), 'doc': 'an integer attribute'}, - {'name': 'attr3', 'doc': 'a 
numeric attribute', - 'type': (float, np.float32, np.float64, np.int8, np.int16, - np.int32, np.int64, int, np.uint8, np.uint16, - np.uint32, np.uint64)}, - {'name': 'attr4', 'doc': 'a float attribute', - 'type': (float, np.float32, np.float64)}, - {'name': 'bar', 'type': EmptyBar, 'doc': 'a group', 'default': None}, + { + "name": "data", + "type": (int, np.int32, np.int64), + "doc": "a dataset", + }, + {"name": "attr1", "type": str, "doc": "a string attribute"}, + { + "name": "attr2", + "type": (int, np.int32, np.int64), + "doc": "an integer attribute", + }, + { + "name": "attr3", + "doc": "a numeric attribute", + "type": ( + float, + np.float32, + np.float64, + np.int8, + np.int16, + np.int32, + np.int64, + int, + np.uint8, + np.uint16, + np.uint32, + np.uint64, + ), + }, + { + "name": "attr4", + "doc": "a float attribute", + "type": (float, np.float32, np.float64), + }, + { + "name": "bar", + "type": EmptyBar, + "doc": "a group", + "default": None, + }, ] not_inherited_fields = { - 'data': spec.get_dataset('data'), - 'attr1': spec.get_attribute('attr1'), - 'attr2': spec.get_dataset('data').get_attribute('attr2'), - 'attr3': spec.get_attribute('attr3'), - 'attr4': spec.get_attribute('attr4'), - 'bar': spec.groups[0] + "data": spec.get_dataset("data"), + "attr1": spec.get_attribute("attr1"), + "attr2": spec.get_dataset("data").get_attribute("attr2"), + "attr3": spec.get_attribute("attr3"), + "attr4": spec.get_attribute("attr4"), + "bar": spec.groups[0], } docval_args = list() @@ -686,367 +901,491 @@ def test_update_docval(self): attr_name=attr_name, not_inherited_fields=not_inherited_fields, type_map=self.type_map, - spec=spec + spec=spec, ) - self.assertListEqual(docval_args, expected[:(i+1)]) # compare with the first i elements of expected + self.assertListEqual(docval_args, expected[: (i + 1)]) # compare with the first i elements of expected def test_update_docval_attr_shape(self): """Test that update_docval_args for an attribute with shape sets the type and shape keys.""" spec = GroupSpec( - doc='A test group specification with a data type', - data_type_def='Baz', + doc="A test group specification with a data type", + data_type_def="Baz", attributes=[ - AttributeSpec(name='attr1', doc='a string attribute', dtype='text', shape=[None]) - ] + AttributeSpec( + name="attr1", + doc="a string attribute", + dtype="text", + shape=[None], + ) + ], ) - not_inherited_fields = {'attr1': spec.get_attribute('attr1')} + not_inherited_fields = {"attr1": spec.get_attribute("attr1")} docval_args = list() CustomClassGenerator.process_field_spec( classdict={}, docval_args=docval_args, parent_cls=EmptyBar, # <-- arbitrary class - attr_name='attr1', + attr_name="attr1", not_inherited_fields=not_inherited_fields, type_map=TypeMap(), - spec=spec + spec=spec, ) - expected = [{'name': 'attr1', 'type': ('array_data', 'data'), 'doc': 'a string attribute', 'shape': [None]}] + expected = [ + { + "name": "attr1", + "type": ("array_data", "data"), + "doc": "a string attribute", + "shape": [None], + } + ] self.assertListEqual(docval_args, expected) def test_update_docval_dset_shape(self): """Test that update_docval_args for a dataset with shape sets the type and shape keys.""" spec = GroupSpec( - doc='A test group specification with a data type', - data_type_def='Baz', + doc="A test group specification with a data type", + data_type_def="Baz", datasets=[ - DatasetSpec(name='dset1', doc='a string dataset', dtype='text', shape=[None]) - ] + DatasetSpec( + name="dset1", + doc="a string dataset", + dtype="text", + 
shape=[None], + ) + ], ) - not_inherited_fields = {'dset1': spec.get_dataset('dset1')} + not_inherited_fields = {"dset1": spec.get_dataset("dset1")} docval_args = list() CustomClassGenerator.process_field_spec( classdict={}, docval_args=docval_args, parent_cls=EmptyBar, # <-- arbitrary class - attr_name='dset1', + attr_name="dset1", not_inherited_fields=not_inherited_fields, type_map=TypeMap(), - spec=spec + spec=spec, ) - expected = [{'name': 'dset1', 'type': ('array_data', 'data'), 'doc': 'a string dataset', 'shape': [None]}] + expected = [ + { + "name": "dset1", + "type": ("array_data", "data"), + "doc": "a string dataset", + "shape": [None], + } + ] self.assertListEqual(docval_args, expected) def test_update_docval_default_value(self): """Test that update_docval_args for an optional field with default value sets the default key.""" spec = GroupSpec( - doc='A test group specification with a data type', - data_type_def='Baz', + doc="A test group specification with a data type", + data_type_def="Baz", attributes=[ - AttributeSpec(name='attr1', doc='a string attribute', dtype='text', required=False, - default_value='value') - ] + AttributeSpec( + name="attr1", + doc="a string attribute", + dtype="text", + required=False, + default_value="value", + ) + ], ) - not_inherited_fields = {'attr1': spec.get_attribute('attr1')} + not_inherited_fields = {"attr1": spec.get_attribute("attr1")} docval_args = list() CustomClassGenerator.process_field_spec( classdict={}, docval_args=docval_args, parent_cls=EmptyBar, # <-- arbitrary class - attr_name='attr1', + attr_name="attr1", not_inherited_fields=not_inherited_fields, type_map=TypeMap(), - spec=spec + spec=spec, ) - expected = [{'name': 'attr1', 'type': str, 'doc': 'a string attribute', 'default': 'value'}] + expected = [ + { + "name": "attr1", + "type": str, + "doc": "a string attribute", + "default": "value", + } + ] self.assertListEqual(docval_args, expected) def test_update_docval_default_value_none(self): """Test that update_docval_args for an optional field sets default: None.""" spec = GroupSpec( - doc='A test group specification with a data type', - data_type_def='Baz', + doc="A test group specification with a data type", + data_type_def="Baz", attributes=[ - AttributeSpec(name='attr1', doc='a string attribute', dtype='text', required=False) - ] + AttributeSpec( + name="attr1", + doc="a string attribute", + dtype="text", + required=False, + ) + ], ) - not_inherited_fields = {'attr1': spec.get_attribute('attr1')} + not_inherited_fields = {"attr1": spec.get_attribute("attr1")} docval_args = list() CustomClassGenerator.process_field_spec( classdict={}, docval_args=docval_args, parent_cls=EmptyBar, # <-- arbitrary class - attr_name='attr1', + attr_name="attr1", not_inherited_fields=not_inherited_fields, type_map=TypeMap(), - spec=spec + spec=spec, ) - expected = [{'name': 'attr1', 'type': str, 'doc': 'a string attribute', 'default': None}] + expected = [ + { + "name": "attr1", + "type": str, + "doc": "a string attribute", + "default": None, + } + ] self.assertListEqual(docval_args, expected) def test_update_docval_default_value_none_required_parent(self): """Test that update_docval_args for an optional field with a required parent sets default: None.""" spec = GroupSpec( - doc='A test group specification with a data type', - data_type_def='Baz', + doc="A test group specification with a data type", + data_type_def="Baz", groups=[ GroupSpec( - name='group1', - doc='required untyped group', + name="group1", + doc="required untyped group", 
attributes=[ - AttributeSpec(name='attr1', doc='a string attribute', dtype='text', required=False) - ] + AttributeSpec( + name="attr1", + doc="a string attribute", + dtype="text", + required=False, + ) + ], ) - ] + ], ) - not_inherited_fields = {'attr1': spec.get_group('group1').get_attribute('attr1')} + not_inherited_fields = {"attr1": spec.get_group("group1").get_attribute("attr1")} docval_args = list() CustomClassGenerator.process_field_spec( classdict={}, docval_args=docval_args, parent_cls=EmptyBar, # <-- arbitrary class - attr_name='attr1', + attr_name="attr1", not_inherited_fields=not_inherited_fields, type_map=TypeMap(), - spec=spec + spec=spec, ) - expected = [{'name': 'attr1', 'type': str, 'doc': 'a string attribute', 'default': None}] + expected = [ + { + "name": "attr1", + "type": str, + "doc": "a string attribute", + "default": None, + } + ] self.assertListEqual(docval_args, expected) def test_update_docval_required_field_optional_parent(self): """Test that update_docval_args for a required field with an optional parent sets default: None.""" spec = GroupSpec( - doc='A test group specification with a data type', - data_type_def='Baz', + doc="A test group specification with a data type", + data_type_def="Baz", groups=[ GroupSpec( - name='group1', - doc='required untyped group', - attributes=[ - AttributeSpec(name='attr1', doc='a string attribute', dtype='text') - ], - quantity='?' + name="group1", + doc="required untyped group", + attributes=[AttributeSpec(name="attr1", doc="a string attribute", dtype="text")], + quantity="?", ) - ] + ], ) - not_inherited_fields = {'attr1': spec.get_group('group1').get_attribute('attr1')} + not_inherited_fields = {"attr1": spec.get_group("group1").get_attribute("attr1")} docval_args = list() CustomClassGenerator.process_field_spec( classdict={}, docval_args=docval_args, parent_cls=EmptyBar, # <-- arbitrary class - attr_name='attr1', + attr_name="attr1", not_inherited_fields=not_inherited_fields, type_map=TypeMap(), - spec=spec + spec=spec, ) - expected = [{'name': 'attr1', 'type': str, 'doc': 'a string attribute', 'default': None}] + expected = [ + { + "name": "attr1", + "type": str, + "doc": "a string attribute", + "default": None, + } + ] self.assertListEqual(docval_args, expected) def test_process_field_spec_overwrite(self): """Test that docval generation overwrites previous docval args.""" spec = GroupSpec( - doc='A test group specification with a data type', - data_type_def='Baz', + doc="A test group specification with a data type", + data_type_def="Baz", attributes=[ - AttributeSpec(name='attr1', doc='a string attribute', dtype='text', shape=[None]) - ] + AttributeSpec( + name="attr1", + doc="a string attribute", + dtype="text", + shape=[None], + ) + ], ) - not_inherited_fields = {'attr1': spec.get_attribute('attr1')} - - docval_args = [{'name': 'attr1', 'type': ('array_data', 'data'), 'doc': 'a string attribute', - 'shape': [[None], [None, None]]}, # this dict will be overwritten below - {'name': 'attr2', 'type': ('array_data', 'data'), 'doc': 'a string attribute', - 'shape': [[None], [None, None]]}] + not_inherited_fields = {"attr1": spec.get_attribute("attr1")} + + docval_args = [ + { + "name": "attr1", + "type": ("array_data", "data"), + "doc": "a string attribute", + "shape": [[None], [None, None]], + }, # this dict will be overwritten below + { + "name": "attr2", + "type": ("array_data", "data"), + "doc": "a string attribute", + "shape": [[None], [None, None]], + }, + ] CustomClassGenerator.process_field_spec( classdict={}, 
docval_args=docval_args, parent_cls=EmptyBar, # <-- arbitrary class - attr_name='attr1', + attr_name="attr1", not_inherited_fields=not_inherited_fields, type_map=TypeMap(), - spec=spec + spec=spec, ) - expected = [{'name': 'attr1', 'type': ('array_data', 'data'), 'doc': 'a string attribute', - 'shape': [None]}, - {'name': 'attr2', 'type': ('array_data', 'data'), 'doc': 'a string attribute', - 'shape': [[None], [None, None]]}] + expected = [ + { + "name": "attr1", + "type": ("array_data", "data"), + "doc": "a string attribute", + "shape": [None], + }, + { + "name": "attr2", + "type": ("array_data", "data"), + "doc": "a string attribute", + "shape": [[None], [None, None]], + }, + ] self.assertListEqual(docval_args, expected) def test_process_field_spec_link(self): """Test that processing a link spec does not set child=True in __fields__.""" classdict = {} - not_inherited_fields = {'attr3': LinkSpec(name='attr3', target_type='EmptyBar', doc='a link')} + not_inherited_fields = {"attr3": LinkSpec(name="attr3", target_type="EmptyBar", doc="a link")} CustomClassGenerator.process_field_spec( classdict=classdict, docval_args=[], parent_cls=EmptyBar, # <-- arbitrary class - attr_name='attr3', + attr_name="attr3", not_inherited_fields=not_inherited_fields, type_map=self.type_map, - spec=GroupSpec('dummy', 'doc') + spec=GroupSpec("dummy", "doc"), ) - expected = {'__fields__': [{'name': 'attr3', 'doc': 'a link'}]} + expected = {"__fields__": [{"name": "attr3", "doc": "a link"}]} self.assertDictEqual(classdict, expected) def test_post_process_fixed_name(self): """Test that docval generation for a class with a fixed name does not contain a docval arg for name.""" spec = GroupSpec( - doc='A test group specification with a data type', - data_type_def='Baz', - name='MyBaz', # <-- fixed name + doc="A test group specification with a data type", + data_type_def="Baz", + name="MyBaz", # <-- fixed name attributes=[ AttributeSpec( - name='attr1', - doc='a string attribute', - dtype='text', - shape=[None] + name="attr1", + doc="a string attribute", + dtype="text", + shape=[None], ) - ] + ], ) classdict = {} bases = [Container] - docval_args = [{'name': 'name', 'type': str, 'doc': 'name'}, - {'name': 'attr1', 'type': ('array_data', 'data'), 'doc': 'a string attribute', - 'shape': [None]}] + docval_args = [ + {"name": "name", "type": str, "doc": "name"}, + { + "name": "attr1", + "type": ("array_data", "data"), + "doc": "a string attribute", + "shape": [None], + }, + ] CustomClassGenerator.post_process(classdict, bases, docval_args, spec) - expected = [{'name': 'attr1', 'type': ('array_data', 'data'), 'doc': 'a string attribute', - 'shape': [None]}] + expected = [ + { + "name": "attr1", + "type": ("array_data", "data"), + "doc": "a string attribute", + "shape": [None], + } + ] self.assertListEqual(docval_args, expected) def test_post_process_default_name(self): """Test that docval generation for a class with a default name has the default value for name set.""" spec = GroupSpec( - doc='A test group specification with a data type', - data_type_def='Baz', - default_name='MyBaz', # <-- default name + doc="A test group specification with a data type", + data_type_def="Baz", + default_name="MyBaz", # <-- default name attributes=[ AttributeSpec( - name='attr1', - doc='a string attribute', - dtype='text', - shape=[None] + name="attr1", + doc="a string attribute", + dtype="text", + shape=[None], ) - ] + ], ) classdict = {} bases = [Container] - docval_args = [{'name': 'name', 'type': str, 'doc': 'name'}, - {'name': 'attr1', 
'type': ('array_data', 'data'), 'doc': 'a string attribute', - 'shape': [None]}] + docval_args = [ + {"name": "name", "type": str, "doc": "name"}, + { + "name": "attr1", + "type": ("array_data", "data"), + "doc": "a string attribute", + "shape": [None], + }, + ] CustomClassGenerator.post_process(classdict, bases, docval_args, spec) - expected = [{'name': 'name', 'type': str, 'doc': 'name', 'default': 'MyBaz'}, - {'name': 'attr1', 'type': ('array_data', 'data'), 'doc': 'a string attribute', - 'shape': [None]}] + expected = [ + {"name": "name", "type": str, "doc": "name", "default": "MyBaz"}, + { + "name": "attr1", + "type": ("array_data", "data"), + "doc": "a string attribute", + "shape": [None], + }, + ] self.assertListEqual(docval_args, expected) class TestMCIProcessFieldSpec(TestCase): - def setUp(self): bar_spec = GroupSpec( - doc='A test group specification with a data type', - data_type_def='EmptyBar' + doc="A test group specification with a data type", + data_type_def="EmptyBar", ) specs = [bar_spec] - container_classes = {'EmptyBar': EmptyBar} + container_classes = {"EmptyBar": EmptyBar} self.type_map = create_test_type_map(specs, container_classes) def test_update_docval(self): - spec = GroupSpec(data_type_inc='EmptyBar', doc='test multi', quantity='*') + spec = GroupSpec(data_type_inc="EmptyBar", doc="test multi", quantity="*") classdict = dict() docval_args = [] - not_inherited_fields = {'empty_bars': spec} + not_inherited_fields = {"empty_bars": spec} MCIClassGenerator.process_field_spec( classdict=classdict, docval_args=docval_args, parent_cls=Container, - attr_name='empty_bars', + attr_name="empty_bars", not_inherited_fields=not_inherited_fields, type_map=self.type_map, - spec=spec + spec=spec, ) expected = [ dict( - attr='empty_bars', + attr="empty_bars", type=EmptyBar, - add='add_empty_bars', - get='get_empty_bars', - create='create_empty_bars' + add="add_empty_bars", + get="get_empty_bars", + create="create_empty_bars", ) ] - self.assertEqual(classdict['__clsconf__'], expected) + self.assertEqual(classdict["__clsconf__"], expected) def test_update_init_zero_or_more(self): - spec = GroupSpec(data_type_inc='EmptyBar', doc='test multi', quantity='*') + spec = GroupSpec(data_type_inc="EmptyBar", doc="test multi", quantity="*") classdict = dict() docval_args = [] - not_inherited_fields = {'empty_bars': spec} + not_inherited_fields = {"empty_bars": spec} MCIClassGenerator.process_field_spec( classdict=classdict, docval_args=docval_args, parent_cls=Container, - attr_name='empty_bars', + attr_name="empty_bars", not_inherited_fields=not_inherited_fields, type_map=self.type_map, - spec=spec + spec=spec, ) - expected = [{'name': 'empty_bars', 'type': (list, tuple, dict, EmptyBar), 'doc': 'test multi', 'default': None}] + expected = [ + { + "name": "empty_bars", + "type": (list, tuple, dict, EmptyBar), + "doc": "test multi", + "default": None, + } + ] self.assertListEqual(docval_args, expected) def test_update_init_one_or_more(self): - spec = GroupSpec(data_type_inc='EmptyBar', doc='test multi', quantity='+') + spec = GroupSpec(data_type_inc="EmptyBar", doc="test multi", quantity="+") classdict = dict() docval_args = [] - not_inherited_fields = {'empty_bars': spec} + not_inherited_fields = {"empty_bars": spec} MCIClassGenerator.process_field_spec( classdict=classdict, docval_args=docval_args, parent_cls=Bar, - attr_name='empty_bars', + attr_name="empty_bars", not_inherited_fields=not_inherited_fields, type_map=self.type_map, - spec=spec + spec=spec, ) - expected = [{'name': 
'empty_bars', 'type': (list, tuple, dict, EmptyBar), 'doc': 'test multi'}] + expected = [ + { + "name": "empty_bars", + "type": (list, tuple, dict, EmptyBar), + "doc": "test multi", + } + ] self.assertListEqual(docval_args, expected) def test_post_process(self): multi_spec = GroupSpec( - doc='A test extension that contains a multi', - data_type_def='Multi', - groups=[ - GroupSpec(data_type_inc='EmptyBar', doc='test multi', quantity='*') - ], + doc="A test extension that contains a multi", + data_type_def="Multi", + groups=[GroupSpec(data_type_inc="EmptyBar", doc="test multi", quantity="*")], ) classdict = dict( __clsconf__=[ dict( - attr='empty_bars', + attr="empty_bars", type=EmptyBar, - add='add_empty_bars', - get='get_empty_bars', - create='create_empty_bars' + add="add_empty_bars", + get="get_empty_bars", + create="create_empty_bars", ) ] ) @@ -1060,21 +1399,19 @@ class Multi1(MultiContainerInterface): pass multi_spec = GroupSpec( - doc='A test extension that contains a multi and extends a multi', - data_type_def='Multi2', - data_type_inc='Multi1', - groups=[ - GroupSpec(data_type_inc='EmptyBar', doc='test multi', quantity='*') - ], + doc="A test extension that contains a multi and extends a multi", + data_type_def="Multi2", + data_type_inc="Multi1", + groups=[GroupSpec(data_type_inc="EmptyBar", doc="test multi", quantity="*")], ) classdict = dict( __clsconf__=[ dict( - attr='empty_bars', + attr="empty_bars", type=EmptyBar, - add='add_empty_bars', - get='get_empty_bars', - create='create_empty_bars' + add="add_empty_bars", + get="get_empty_bars", + create="create_empty_bars", ) ] ) @@ -1088,20 +1425,18 @@ class Multi1(MultiContainerInterface): pass multi_spec = GroupSpec( - doc='A test extension that contains a multi and extends a multi', - data_type_def='Multi1', - groups=[ - GroupSpec(data_type_inc='EmptyBar', doc='test multi', quantity='*') - ], + doc="A test extension that contains a multi and extends a multi", + data_type_def="Multi1", + groups=[GroupSpec(data_type_inc="EmptyBar", doc="test multi", quantity="*")], ) classdict = dict( __clsconf__=[ dict( - attr='empty_bars', + attr="empty_bars", type=EmptyBar, - add='add_empty_bars', - get='get_empty_bars', - create='create_empty_bars' + add="add_empty_bars", + get="get_empty_bars", + create="create_empty_bars", ) ] ) diff --git a/tests/unit/build_tests/test_convert_dtype.py b/tests/unit/build_tests/test_convert_dtype.py index bf9b2a95f..bbf0244af 100644 --- a/tests/unit/build_tests/test_convert_dtype.py +++ b/tests/unit/build_tests/test_convert_dtype.py @@ -1,21 +1,25 @@ from datetime import datetime import numpy as np + from hdmf.backends.hdf5 import H5DataIO from hdmf.build import ObjectMapper from hdmf.data_utils import DataChunkIterator -from hdmf.spec import DatasetSpec, RefSpec, DtypeSpec +from hdmf.spec import DatasetSpec, DtypeSpec, RefSpec from hdmf.testing import TestCase class TestConvertDtype(TestCase): - def test_value_none(self): - spec = DatasetSpec('an example dataset', 'int', name='data') - self.assertTupleEqual(ObjectMapper.convert_dtype(spec, None), (None, 'int')) + spec = DatasetSpec("an example dataset", "int", name="data") + self.assertTupleEqual(ObjectMapper.convert_dtype(spec, None), (None, "int")) - spec = DatasetSpec('an example dataset', RefSpec(reftype='object', target_type='int'), name='data') - self.assertTupleEqual(ObjectMapper.convert_dtype(spec, None), (None, 'object')) + spec = DatasetSpec( + "an example dataset", + RefSpec(reftype="object", target_type="int"), + name="data", + ) + 
self.assertTupleEqual(ObjectMapper.convert_dtype(spec, None), (None, "object")) # do full matrix test of given value x and spec y, what does convert_dtype return? def test_convert_to_64bit_spec(self): @@ -23,31 +27,77 @@ def test_convert_to_64bit_spec(self): Test that if given any value for a spec with a 64-bit dtype, convert_dtype will convert to the spec type. Also test that if the given value is not the same as the spec, convert_dtype raises a warning. """ - spec_type = 'float64' - value_types = ['double', 'float64'] + spec_type = "float64" + value_types = ["double", "float64"] self._test_convert_alias(spec_type, value_types) - spec_type = 'float64' - value_types = ['float', 'float32', 'long', 'int64', 'int', 'int32', 'int16', 'short', 'int8', 'uint64', 'uint', - 'uint32', 'uint16', 'uint8', 'bool'] + spec_type = "float64" + value_types = [ + "float", + "float32", + "long", + "int64", + "int", + "int32", + "int16", + "short", + "int8", + "uint64", + "uint", + "uint32", + "uint16", + "uint8", + "bool", + ] self._test_convert_higher_precision_helper(spec_type, value_types) - spec_type = 'int64' - value_types = ['long', 'int64'] + spec_type = "int64" + value_types = ["long", "int64"] self._test_convert_alias(spec_type, value_types) - spec_type = 'int64' - value_types = ['double', 'float64', 'float', 'float32', 'int', 'int32', 'int16', 'short', 'int8', 'uint64', - 'uint', 'uint32', 'uint16', 'uint8', 'bool'] + spec_type = "int64" + value_types = [ + "double", + "float64", + "float", + "float32", + "int", + "int32", + "int16", + "short", + "int8", + "uint64", + "uint", + "uint32", + "uint16", + "uint8", + "bool", + ] self._test_convert_higher_precision_helper(spec_type, value_types) - spec_type = 'uint64' - value_types = ['uint64'] + spec_type = "uint64" + value_types = ["uint64"] self._test_convert_alias(spec_type, value_types) - spec_type = 'uint64' - value_types = ['double', 'float64', 'float', 'float32', 'long', 'int64', 'int', 'int32', 'int16', 'short', - 'int8', 'uint', 'uint32', 'uint16', 'uint8', 'bool'] + spec_type = "uint64" + value_types = [ + "double", + "float64", + "float", + "float32", + "long", + "int64", + "int", + "int32", + "int16", + "short", + "int8", + "uint", + "uint32", + "uint16", + "uint8", + "bool", + ] self._test_convert_higher_precision_helper(spec_type, value_types) def test_convert_to_float32_spec(self): @@ -57,18 +107,29 @@ def test_convert_to_float32_spec(self): If given a value that is float32, convert_dtype will convert to float32. If given a value with precision <= float32, convert_dtype will convert to float32 and raise a warning. 
""" - spec_type = 'float32' - value_types = ['double', 'float64'] + spec_type = "float32" + value_types = ["double", "float64"] self._test_keep_higher_precision_helper(spec_type, value_types) - value_types = ['long', 'int64', 'uint64'] - expected_type = 'float64' + value_types = ["long", "int64", "uint64"] + expected_type = "float64" self._test_change_basetype_helper(spec_type, value_types, expected_type) - value_types = ['float', 'float32'] + value_types = ["float", "float32"] self._test_convert_alias(spec_type, value_types) - value_types = ['int', 'int32', 'int16', 'short', 'int8', 'uint', 'uint32', 'uint16', 'uint8', 'bool'] + value_types = [ + "int", + "int32", + "int16", + "short", + "int8", + "uint", + "uint32", + "uint16", + "uint8", + "bool", + ] self._test_convert_higher_precision_helper(spec_type, value_types) def test_convert_to_int32_spec(self): @@ -78,18 +139,29 @@ def test_convert_to_int32_spec(self): If given a value that is int32, convert_dtype will convert to int32. If given a value with precision <= int32, convert_dtype will convert to int32 and raise a warning. """ - spec_type = 'int32' - value_types = ['int64', 'long'] + spec_type = "int32" + value_types = ["int64", "long"] self._test_keep_higher_precision_helper(spec_type, value_types) - value_types = ['double', 'float64', 'uint64'] - expected_type = 'int64' + value_types = ["double", "float64", "uint64"] + expected_type = "int64" self._test_change_basetype_helper(spec_type, value_types, expected_type) - value_types = ['int', 'int32'] + value_types = ["int", "int32"] self._test_convert_alias(spec_type, value_types) - value_types = ['float', 'float32', 'int16', 'short', 'int8', 'uint', 'uint32', 'uint16', 'uint8', 'bool'] + value_types = [ + "float", + "float32", + "int16", + "short", + "int8", + "uint", + "uint32", + "uint16", + "uint8", + "bool", + ] self._test_convert_higher_precision_helper(spec_type, value_types) def test_convert_to_uint32_spec(self): @@ -99,18 +171,29 @@ def test_convert_to_uint32_spec(self): If given a value that is uint32, convert_dtype will convert to uint32. If given a value with precision <= uint32, convert_dtype will convert to uint32 and raise a warning. """ - spec_type = 'uint32' - value_types = ['uint64'] + spec_type = "uint32" + value_types = ["uint64"] self._test_keep_higher_precision_helper(spec_type, value_types) - value_types = ['double', 'float64', 'long', 'int64'] - expected_type = 'uint64' + value_types = ["double", "float64", "long", "int64"] + expected_type = "uint64" self._test_change_basetype_helper(spec_type, value_types, expected_type) - value_types = ['uint', 'uint32'] + value_types = ["uint", "uint32"] self._test_convert_alias(spec_type, value_types) - value_types = ['float', 'float32', 'int', 'int32', 'int16', 'short', 'int8', 'uint16', 'uint8', 'bool'] + value_types = [ + "float", + "float32", + "int", + "int32", + "int16", + "short", + "int8", + "uint16", + "uint8", + "bool", + ] self._test_convert_higher_precision_helper(spec_type, value_types) def test_convert_to_int16_spec(self): @@ -121,22 +204,22 @@ def test_convert_to_int16_spec(self): If given a value that is int16, convert_dtype will convert to int16. If given a value with precision <= int16, convert_dtype will convert to int16 and raise a warning. 
""" - spec_type = 'int16' - value_types = ['long', 'int64', 'int', 'int32'] + spec_type = "int16" + value_types = ["long", "int64", "int", "int32"] self._test_keep_higher_precision_helper(spec_type, value_types) - value_types = ['double', 'float64', 'uint64'] - expected_type = 'int64' + value_types = ["double", "float64", "uint64"] + expected_type = "int64" self._test_change_basetype_helper(spec_type, value_types, expected_type) - value_types = ['float', 'float32', 'uint', 'uint32'] - expected_type = 'int32' + value_types = ["float", "float32", "uint", "uint32"] + expected_type = "int32" self._test_change_basetype_helper(spec_type, value_types, expected_type) - value_types = ['int16', 'short'] + value_types = ["int16", "short"] self._test_convert_alias(spec_type, value_types) - value_types = ['int8', 'uint16', 'uint8', 'bool'] + value_types = ["int8", "uint16", "uint8", "bool"] self._test_convert_higher_precision_helper(spec_type, value_types) def test_convert_to_uint16_spec(self): @@ -147,22 +230,22 @@ def test_convert_to_uint16_spec(self): If given a value that is uint16, convert_dtype will convert to uint16. If given a value with precision <= uint16, convert_dtype will convert to uint16 and raise a warning. """ - spec_type = 'uint16' - value_types = ['uint64', 'uint', 'uint32'] + spec_type = "uint16" + value_types = ["uint64", "uint", "uint32"] self._test_keep_higher_precision_helper(spec_type, value_types) - value_types = ['double', 'float64', 'long', 'int64'] - expected_type = 'uint64' + value_types = ["double", "float64", "long", "int64"] + expected_type = "uint64" self._test_change_basetype_helper(spec_type, value_types, expected_type) - value_types = ['float', 'float32', 'int', 'int32'] - expected_type = 'uint32' + value_types = ["float", "float32", "int", "int32"] + expected_type = "uint32" self._test_change_basetype_helper(spec_type, value_types, expected_type) - value_types = ['uint16'] + value_types = ["uint16"] self._test_convert_alias(spec_type, value_types) - value_types = ['int16', 'short', 'int8', 'uint8', 'bool'] + value_types = ["int16", "short", "int8", "uint8", "bool"] self._test_convert_higher_precision_helper(spec_type, value_types) def test_convert_to_bool_spec(self): @@ -171,15 +254,29 @@ def test_convert_to_bool_spec(self): If given a value with type int8/uint8, convert_dtype will convert to bool and raise a warning. Otherwise, convert_dtype will raise an error. 
""" - spec_type = 'bool' - value_types = ['bool'] + spec_type = "bool" + value_types = ["bool"] self._test_convert_alias(spec_type, value_types) - value_types = ['uint8', 'int8'] + value_types = ["uint8", "int8"] self._test_convert_higher_precision_helper(spec_type, value_types) - value_types = ['double', 'float64', 'float', 'float32', 'long', 'int64', 'int', 'int32', 'int16', 'short', - 'uint64', 'uint', 'uint32', 'uint16'] + value_types = [ + "double", + "float64", + "float", + "float32", + "long", + "int64", + "int", + "int32", + "int16", + "short", + "uint64", + "uint", + "uint32", + "uint16", + ] self._test_convert_mismatch_helper(spec_type, value_types) def _get_type(self, type_str): @@ -187,7 +284,7 @@ def _get_type(self, type_str): def _test_convert_alias(self, spec_type, value_types): data = 1 - spec = DatasetSpec('an example dataset', spec_type, name='data') + spec = DatasetSpec("an example dataset", spec_type, name="data") match = (self._get_type(spec_type)(data), self._get_type(spec_type)) for dtype in value_types: value = self._get_type(dtype)(data) # convert data to given dtype @@ -198,15 +295,17 @@ def _test_convert_alias(self, spec_type, value_types): def _test_convert_higher_precision_helper(self, spec_type, value_types): data = 1 - spec = DatasetSpec('an example dataset', spec_type, name='data') + spec = DatasetSpec("an example dataset", spec_type, name="data") match = (self._get_type(spec_type)(data), self._get_type(spec_type)) for dtype in value_types: value = self._get_type(dtype)(data) # convert data to given dtype with self.subTest(dtype=dtype): s = np.dtype(self._get_type(spec_type)) g = np.dtype(self._get_type(dtype)) - msg = ("Spec 'data': Value with data type %s is being converted to data type %s as specified." - % (g.name, s.name)) + msg = "Spec 'data': Value with data type %s is being converted to data type %s as specified." % ( + g.name, + s.name, + ) with self.assertWarnsWith(UserWarning, msg): ret = ObjectMapper.convert_dtype(spec, value) self.assertTupleEqual(ret, match) @@ -214,7 +313,7 @@ def _test_convert_higher_precision_helper(self, spec_type, value_types): def _test_keep_higher_precision_helper(self, spec_type, value_types): data = 1 - spec = DatasetSpec('an example dataset', spec_type, name='data') + spec = DatasetSpec("an example dataset", spec_type, name="data") for dtype in value_types: value = self._get_type(dtype)(data) match = (value, self._get_type(dtype)) @@ -225,7 +324,7 @@ def _test_keep_higher_precision_helper(self, spec_type, value_types): def _test_change_basetype_helper(self, spec_type, value_types, exp_type): data = 1 - spec = DatasetSpec('an example dataset', spec_type, name='data') + spec = DatasetSpec("an example dataset", spec_type, name="data") match = (self._get_type(exp_type)(data), self._get_type(exp_type)) for dtype in value_types: value = self._get_type(dtype)(data) # convert data to given dtype @@ -233,8 +332,10 @@ def _test_change_basetype_helper(self, spec_type, value_types, exp_type): s = np.dtype(self._get_type(spec_type)) e = np.dtype(self._get_type(exp_type)) g = np.dtype(self._get_type(dtype)) - msg = ("Spec 'data': Value with data type %s is being converted to data type %s " - "(min specification: %s)." % (g.name, e.name, s.name)) + msg = ( + "Spec 'data': Value with data type %s is being converted to data type %s (min specification: %s)." 
+ % (g.name, e.name, s.name) + ) with self.assertWarnsWith(UserWarning, msg): ret = ObjectMapper.convert_dtype(spec, value) self.assertTupleEqual(ret, match) @@ -242,7 +343,7 @@ def _test_change_basetype_helper(self, spec_type, value_types, exp_type): def _test_convert_mismatch_helper(self, spec_type, value_types): data = 1 - spec = DatasetSpec('an example dataset', spec_type, name='data') + spec = DatasetSpec("an example dataset", spec_type, name="data") for dtype in value_types: value = self._get_type(dtype)(data) # convert data to given dtype with self.subTest(dtype=dtype): @@ -253,7 +354,7 @@ def _test_convert_mismatch_helper(self, spec_type, value_types): ObjectMapper.convert_dtype(spec, value) def test_dci_input(self): - spec = DatasetSpec('an example dataset', 'int64', name='data') + spec = DatasetSpec("an example dataset", "int64", name="data") value = DataChunkIterator(np.array([1, 2, 3], dtype=np.int32)) msg = "Spec 'data': Value with data type int32 is being converted to data type int64 as specified." with self.assertWarnsWith(UserWarning, msg): @@ -261,123 +362,123 @@ def test_dci_input(self): self.assertIs(ret, value) self.assertEqual(ret_dtype, np.int64) - spec = DatasetSpec('an example dataset', 'int16', name='data') + spec = DatasetSpec("an example dataset", "int16", name="data") value = DataChunkIterator(np.array([1, 2, 3], dtype=np.int32)) ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) # no conversion self.assertIs(ret, value) self.assertEqual(ret_dtype, np.int32) # increase precision def test_text_spec(self): - text_spec_types = ['text', 'utf', 'utf8', 'utf-8'] + text_spec_types = ["text", "utf", "utf8", "utf-8"] for spec_type in text_spec_types: with self.subTest(spec_type=spec_type): - spec = DatasetSpec('an example dataset', spec_type, name='data') + spec = DatasetSpec("an example dataset", spec_type, name="data") - value = 'a' + value = "a" ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) self.assertEqual(ret, value) self.assertIs(type(ret), str) - self.assertEqual(ret_dtype, 'utf8') + self.assertEqual(ret_dtype, "utf8") - value = b'a' + value = b"a" ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) - self.assertEqual(ret, 'a') + self.assertEqual(ret, "a") self.assertIs(type(ret), str) - self.assertEqual(ret_dtype, 'utf8') + self.assertEqual(ret_dtype, "utf8") - value = ['a', 'b'] + value = ["a", "b"] ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) self.assertListEqual(ret, value) self.assertIs(type(ret[0]), str) - self.assertEqual(ret_dtype, 'utf8') + self.assertEqual(ret_dtype, "utf8") - value = np.array(['a', 'b']) + value = np.array(["a", "b"]) ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) np.testing.assert_array_equal(ret, value) - self.assertEqual(ret_dtype, 'utf8') + self.assertEqual(ret_dtype, "utf8") - value = np.array(['a', 'b'], dtype='S1') + value = np.array(["a", "b"], dtype="S1") ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) - np.testing.assert_array_equal(ret, np.array(['a', 'b'], dtype='U1')) - self.assertEqual(ret_dtype, 'utf8') + np.testing.assert_array_equal(ret, np.array(["a", "b"], dtype="U1")) + self.assertEqual(ret_dtype, "utf8") value = [] ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) self.assertListEqual(ret, value) - self.assertEqual(ret_dtype, 'utf8') + self.assertEqual(ret_dtype, "utf8") value = 1 msg = "Expected unicode or ascii string, got " with self.assertRaisesWith(ValueError, msg): ObjectMapper.convert_dtype(spec, value) - value = DataChunkIterator(np.array(['a', 
'b'])) + value = DataChunkIterator(np.array(["a", "b"])) ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) # no conversion self.assertIs(ret, value) - self.assertEqual(ret_dtype, 'utf8') + self.assertEqual(ret_dtype, "utf8") - value = DataChunkIterator(np.array(['a', 'b'], dtype='S1')) + value = DataChunkIterator(np.array(["a", "b"], dtype="S1")) ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) # no conversion self.assertIs(ret, value) - self.assertEqual(ret_dtype, 'utf8') + self.assertEqual(ret_dtype, "utf8") def test_ascii_spec(self): - ascii_spec_types = ['ascii', 'bytes'] + ascii_spec_types = ["ascii", "bytes"] for spec_type in ascii_spec_types: with self.subTest(spec_type=spec_type): - spec = DatasetSpec('an example dataset', spec_type, name='data') + spec = DatasetSpec("an example dataset", spec_type, name="data") - value = 'a' + value = "a" ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) - self.assertEqual(ret, b'a') + self.assertEqual(ret, b"a") self.assertIs(type(ret), bytes) - self.assertEqual(ret_dtype, 'ascii') + self.assertEqual(ret_dtype, "ascii") - value = b'a' + value = b"a" ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) - self.assertEqual(ret, b'a') + self.assertEqual(ret, b"a") self.assertIs(type(ret), bytes) - self.assertEqual(ret_dtype, 'ascii') + self.assertEqual(ret_dtype, "ascii") - value = ['a', 'b'] + value = ["a", "b"] ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) - self.assertListEqual(ret, [b'a', b'b']) + self.assertListEqual(ret, [b"a", b"b"]) self.assertIs(type(ret[0]), bytes) - self.assertEqual(ret_dtype, 'ascii') + self.assertEqual(ret_dtype, "ascii") - value = np.array(['a', 'b']) + value = np.array(["a", "b"]) ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) - np.testing.assert_array_equal(ret, np.array(['a', 'b'], dtype='S1')) - self.assertEqual(ret_dtype, 'ascii') + np.testing.assert_array_equal(ret, np.array(["a", "b"], dtype="S1")) + self.assertEqual(ret_dtype, "ascii") - value = np.array(['a', 'b'], dtype='S1') + value = np.array(["a", "b"], dtype="S1") ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) np.testing.assert_array_equal(ret, value) - self.assertEqual(ret_dtype, 'ascii') + self.assertEqual(ret_dtype, "ascii") value = [] ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) self.assertListEqual(ret, value) - self.assertEqual(ret_dtype, 'ascii') + self.assertEqual(ret_dtype, "ascii") value = 1 msg = "Expected unicode or ascii string, got " with self.assertRaisesWith(ValueError, msg): ObjectMapper.convert_dtype(spec, value) - value = DataChunkIterator(np.array(['a', 'b'])) + value = DataChunkIterator(np.array(["a", "b"])) ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) # no conversion self.assertIs(ret, value) - self.assertEqual(ret_dtype, 'ascii') + self.assertEqual(ret_dtype, "ascii") - value = DataChunkIterator(np.array(['a', 'b'], dtype='S1')) + value = DataChunkIterator(np.array(["a", "b"], dtype="S1")) ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) # no conversion self.assertIs(ret, value) - self.assertEqual(ret_dtype, 'ascii') + self.assertEqual(ret_dtype, "ascii") def test_no_spec(self): spec_type = None - spec = DatasetSpec('an example dataset', spec_type, name='data') + spec = DatasetSpec("an example dataset", spec_type, name="data") value = [1, 2, 3] ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) @@ -391,27 +492,27 @@ def test_no_spec(self): self.assertIs(type(ret), np.uint64) self.assertEqual(ret_dtype, np.uint64) - value = 'hello' + value = 
"hello" ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) self.assertEqual(ret, value) self.assertIs(type(ret), str) - self.assertEqual(ret_dtype, 'utf8') + self.assertEqual(ret_dtype, "utf8") - value = b'hello' + value = b"hello" ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) self.assertEqual(ret, value) self.assertIs(type(ret), bytes) - self.assertEqual(ret_dtype, 'ascii') + self.assertEqual(ret_dtype, "ascii") - value = np.array(['aa', 'bb']) + value = np.array(["aa", "bb"]) ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) np.testing.assert_array_equal(ret, value) - self.assertEqual(ret_dtype, 'utf8') + self.assertEqual(ret_dtype, "utf8") - value = np.array(['aa', 'bb'], dtype='S2') + value = np.array(["aa", "bb"], dtype="S2") ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) np.testing.assert_array_equal(ret, value) - self.assertEqual(ret_dtype, 'ascii') + self.assertEqual(ret_dtype, "ascii") value = DataChunkIterator(data=[1, 2, 3]) ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) @@ -420,12 +521,12 @@ def test_no_spec(self): self.assertIs(type(ret.data[0]), int) self.assertEqual(ret_dtype, np.dtype(int).type) - value = DataChunkIterator(data=['a', 'b']) + value = DataChunkIterator(data=["a", "b"]) ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) self.assertEqual(ret, value) self.assertIs(ret.dtype.type, np.str_) self.assertIs(type(ret.data[0]), str) - self.assertEqual(ret_dtype, 'utf8') + self.assertEqual(ret_dtype, "utf8") value = H5DataIO(np.arange(30).reshape(5, 2, 3)) ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) @@ -433,17 +534,17 @@ def test_no_spec(self): self.assertIs(ret.data.dtype.type, np.dtype(int).type) self.assertEqual(ret_dtype, np.dtype(int).type) - value = H5DataIO(['foo', 'bar']) + value = H5DataIO(["foo", "bar"]) ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) self.assertEqual(ret, value) self.assertIs(type(ret.data[0]), str) - self.assertEqual(ret_dtype, 'utf8') + self.assertEqual(ret_dtype, "utf8") - value = H5DataIO([b'foo', b'bar']) + value = H5DataIO([b"foo", b"bar"]) ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) self.assertEqual(ret, value) self.assertIs(type(ret.data[0]), bytes) - self.assertEqual(ret_dtype, 'ascii') + self.assertEqual(ret_dtype, "ascii") value = [] msg = "Cannot infer dtype of empty list or tuple. Please use numpy array with specified dtype." @@ -451,8 +552,8 @@ def test_no_spec(self): ObjectMapper.convert_dtype(spec, value) def test_numeric_spec(self): - spec_type = 'numeric' - spec = DatasetSpec('an example dataset', spec_type, name='data') + spec_type = "numeric" + spec = DatasetSpec("an example dataset", spec_type, name="data") value = np.uint64(4) ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) @@ -467,12 +568,12 @@ def test_numeric_spec(self): self.assertIs(type(ret.data[0]), int) self.assertEqual(ret_dtype, np.dtype(int).type) - value = ['a', 'b'] + value = ["a", "b"] msg = "Cannot convert from to 'numeric' specification dtype." with self.assertRaisesWith(ValueError, msg): ObjectMapper.convert_dtype(spec, value) - value = np.array(['a', 'b']) + value = np.array(["a", "b"]) msg = "Cannot convert from to 'numeric' specification dtype." 
with self.assertRaisesWith(ValueError, msg): ObjectMapper.convert_dtype(spec, value) @@ -483,8 +584,8 @@ def test_numeric_spec(self): ObjectMapper.convert_dtype(spec, value) def test_bool_spec(self): - spec_type = 'bool' - spec = DatasetSpec('an example dataset', spec_type, name='data') + spec_type = "bool" + spec = DatasetSpec("an example dataset", spec_type, name="data") value = np.bool_(True) ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) @@ -499,43 +600,46 @@ def test_bool_spec(self): self.assertEqual(ret_dtype, np.bool_) def test_override_type_int_restrict_precision(self): - spec = DatasetSpec('an example dataset', 'int8', name='data') - res = ObjectMapper.convert_dtype(spec, np.int64(1), 'int64') + spec = DatasetSpec("an example dataset", "int8", name="data") + res = ObjectMapper.convert_dtype(spec, np.int64(1), "int64") self.assertTupleEqual(res, (np.int64(1), np.int64)) def test_override_type_numeric_to_uint(self): - spec = DatasetSpec('an example dataset', 'numeric', name='data') - res = ObjectMapper.convert_dtype(spec, np.uint32(1), 'uint8') + spec = DatasetSpec("an example dataset", "numeric", name="data") + res = ObjectMapper.convert_dtype(spec, np.uint32(1), "uint8") self.assertTupleEqual(res, (np.uint32(1), np.uint32)) def test_override_type_numeric_to_uint_list(self): - spec = DatasetSpec('an example dataset', 'numeric', name='data') - res = ObjectMapper.convert_dtype(spec, np.uint32((1, 2, 3)), 'uint8') + spec = DatasetSpec("an example dataset", "numeric", name="data") + res = ObjectMapper.convert_dtype(spec, np.uint32((1, 2, 3)), "uint8") np.testing.assert_array_equal(res[0], np.uint32((1, 2, 3))) self.assertEqual(res[1], np.uint32) def test_override_type_none_to_bool(self): - spec = DatasetSpec('an example dataset', None, name='data') - res = ObjectMapper.convert_dtype(spec, True, 'bool') + spec = DatasetSpec("an example dataset", None, name="data") + res = ObjectMapper.convert_dtype(spec, True, "bool") self.assertTupleEqual(res, (True, np.bool_)) def test_compound_type(self): """Test that convert_dtype passes through arguments if spec dtype is a list without any validation.""" - spec_type = [DtypeSpec('an int field', 'f1', 'int'), DtypeSpec('a float field', 'f2', 'float')] - spec = DatasetSpec('an example dataset', spec_type, name='data') - value = ['a', 1, 2.2] + spec_type = [ + DtypeSpec("an int field", "f1", "int"), + DtypeSpec("a float field", "f2", "float"), + ] + spec = DatasetSpec("an example dataset", spec_type, name="data") + value = ["a", 1, 2.2] res, ret_dtype = ObjectMapper.convert_dtype(spec, value) self.assertListEqual(res, value) self.assertListEqual(ret_dtype, spec_type) def test_isodatetime_spec(self): - spec_type = 'isodatetime' - spec = DatasetSpec('an example dataset', spec_type, name='data') + spec_type = "isodatetime" + spec = DatasetSpec("an example dataset", spec_type, name="data") # NOTE: datetime.isoformat is called on all values with a datetime spec before conversion # see ObjectMapper.get_attr_value value = datetime.isoformat(datetime(2020, 11, 10)) ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) - self.assertEqual(ret, b'2020-11-10T00:00:00') + self.assertEqual(ret, b"2020-11-10T00:00:00") self.assertIs(type(ret), bytes) - self.assertEqual(ret_dtype, 'ascii') + self.assertEqual(ret_dtype, "ascii") diff --git a/tests/unit/build_tests/test_io_manager.py b/tests/unit/build_tests/test_io_manager.py index 01421e218..38c963374 100644 --- a/tests/unit/build_tests/test_io_manager.py +++ b/tests/unit/build_tests/test_io_manager.py 
@@ -1,90 +1,103 @@ from abc import ABCMeta, abstractmethod -from hdmf.build import GroupBuilder, DatasetBuilder, ObjectMapper, BuildManager, TypeMap, ContainerConfigurationError -from hdmf.spec import GroupSpec, AttributeSpec, DatasetSpec, SpecCatalog, SpecNamespace, NamespaceCatalog +from hdmf.build import ( + BuildManager, + ContainerConfigurationError, + DatasetBuilder, + GroupBuilder, + ObjectMapper, + TypeMap, +) +from hdmf.spec import ( + AttributeSpec, + DatasetSpec, + GroupSpec, + NamespaceCatalog, + SpecCatalog, + SpecNamespace, +) from hdmf.spec.spec import ZERO_OR_MANY from hdmf.testing import TestCase -from tests.unit.helpers.utils import Foo, FooBucket, CORE_NAMESPACE +from ..helpers.utils import CORE_NAMESPACE, Foo, FooBucket class FooMapper(ObjectMapper): - """Maps nested 'attr2' attribute on dataset 'my_data' to Foo.attr2 in constructor and attribute map - """ + """Maps nested 'attr2' attribute on dataset 'my_data' to Foo.attr2 in constructor and attribute map""" def __init__(self, spec): super().__init__(spec) - my_data_spec = spec.get_dataset('my_data') - self.map_spec('attr2', my_data_spec.get_attribute('attr2')) + my_data_spec = spec.get_dataset("my_data") + self.map_spec("attr2", my_data_spec.get_attribute("attr2")) class TestBase(TestCase): - def setUp(self): self.foo_spec = GroupSpec( - doc='A test group specification with a data type', - data_type_def='Foo', + doc="A test group specification with a data type", + data_type_def="Foo", datasets=[ DatasetSpec( - doc='an example dataset', - dtype='int', - name='my_data', + doc="an example dataset", + dtype="int", + name="my_data", attributes=[ AttributeSpec( - name='attr2', - doc='an example integer attribute', - dtype='int' + name="attr2", + doc="an example integer attribute", + dtype="int", ) - ] + ], ) ], - attributes=[AttributeSpec('attr1', 'an example string attribute', 'text')] + attributes=[AttributeSpec("attr1", "an example string attribute", "text")], ) self.spec_catalog = SpecCatalog() - self.spec_catalog.register_spec(self.foo_spec, 'test.yaml') + self.spec_catalog.register_spec(self.foo_spec, "test.yaml") self.namespace = SpecNamespace( - 'a test namespace', + "a test namespace", CORE_NAMESPACE, - [{'source': 'test.yaml'}], - version='0.1.0', - catalog=self.spec_catalog) + [{"source": "test.yaml"}], + version="0.1.0", + catalog=self.spec_catalog, + ) self.namespace_catalog = NamespaceCatalog() self.namespace_catalog.add_namespace(CORE_NAMESPACE, self.namespace) self.type_map = TypeMap(self.namespace_catalog) - self.type_map.register_container_type(CORE_NAMESPACE, 'Foo', Foo) + self.type_map.register_container_type(CORE_NAMESPACE, "Foo", Foo) self.type_map.register_map(Foo, FooMapper) self.manager = BuildManager(self.type_map) class TestBuildManager(TestBase): - def test_build(self): - container_inst = Foo('my_foo', list(range(10)), 'value1', 10) + container_inst = Foo("my_foo", list(range(10)), "value1", 10) expected = GroupBuilder( - 'my_foo', - datasets={ - 'my_data': - DatasetBuilder( - 'my_data', - list(range(10)), - attributes={'attr2': 10})}, - attributes={'attr1': 'value1', 'namespace': CORE_NAMESPACE, 'data_type': 'Foo', - 'object_id': container_inst.object_id}) + "my_foo", + datasets={"my_data": DatasetBuilder("my_data", list(range(10)), attributes={"attr2": 10})}, + attributes={ + "attr1": "value1", + "namespace": CORE_NAMESPACE, + "data_type": "Foo", + "object_id": container_inst.object_id, + }, + ) builder1 = self.manager.build(container_inst) self.assertDictEqual(builder1, expected) def 
test_build_memoization(self): - container_inst = Foo('my_foo', list(range(10)), 'value1', 10) + container_inst = Foo("my_foo", list(range(10)), "value1", 10) expected = GroupBuilder( - 'my_foo', - datasets={ - 'my_data': DatasetBuilder( - 'my_data', - list(range(10)), - attributes={'attr2': 10})}, - attributes={'attr1': 'value1', 'namespace': CORE_NAMESPACE, 'data_type': 'Foo', - 'object_id': container_inst.object_id}) + "my_foo", + datasets={"my_data": DatasetBuilder("my_data", list(range(10)), attributes={"attr2": 10})}, + attributes={ + "attr1": "value1", + "namespace": CORE_NAMESPACE, + "data_type": "Foo", + "object_id": container_inst.object_id, + }, + ) builder1 = self.manager.build(container_inst) builder2 = self.manager.build(container_inst) self.assertDictEqual(builder1, expected) @@ -92,45 +105,47 @@ def test_build_memoization(self): def test_construct(self): builder = GroupBuilder( - 'my_foo', - datasets={ - 'my_data': DatasetBuilder( - 'my_data', - list(range(10)), - attributes={'attr2': 10})}, - attributes={'attr1': 'value1', 'namespace': CORE_NAMESPACE, 'data_type': 'Foo', - 'object_id': -1}) + "my_foo", + datasets={"my_data": DatasetBuilder("my_data", list(range(10)), attributes={"attr2": 10})}, + attributes={ + "attr1": "value1", + "namespace": CORE_NAMESPACE, + "data_type": "Foo", + "object_id": -1, + }, + ) container = self.manager.construct(builder) self.assertListEqual(container.my_data, list(range(10))) - self.assertEqual(container.attr1, 'value1') + self.assertEqual(container.attr1, "value1") self.assertEqual(container.attr2, 10) def test_construct_memoization(self): builder = GroupBuilder( - 'my_foo', datasets={'my_data': DatasetBuilder( - 'my_data', - list(range(10)), - attributes={'attr2': 10})}, - attributes={'attr1': 'value1', 'namespace': CORE_NAMESPACE, 'data_type': 'Foo', - 'object_id': -1}) + "my_foo", + datasets={"my_data": DatasetBuilder("my_data", list(range(10)), attributes={"attr2": 10})}, + attributes={ + "attr1": "value1", + "namespace": CORE_NAMESPACE, + "data_type": "Foo", + "object_id": -1, + }, + ) container1 = self.manager.construct(builder) container2 = self.manager.construct(builder) self.assertIs(container1, container2) def test_clear_cache(self): - container_inst = Foo('my_foo', list(range(10)), 'value1', 10) + container_inst = Foo("my_foo", list(range(10)), "value1", 10) builder1 = self.manager.build(container_inst) self.manager.clear_cache() builder2 = self.manager.build(container_inst) self.assertIsNot(builder1, builder2) builder = GroupBuilder( - 'my_foo', datasets={'my_data': DatasetBuilder( - 'my_data', - list(range(10)), - attributes={'attr2': 10})}, - attributes={'attr1': 'value1', 'namespace': CORE_NAMESPACE, 'data_type': 'Foo', - 'object_id': -1}) + "my_foo", + datasets={"my_data": DatasetBuilder("my_data", list(range(10)), attributes={"attr2": 10})}, + attributes={"attr1": "value1", "namespace": CORE_NAMESPACE, "data_type": "Foo", "object_id": -1}, + ) container1 = self.manager.construct(builder) self.manager.clear_cache() container2 = self.manager.construct(builder) @@ -138,50 +153,59 @@ def test_clear_cache(self): class NestedBaseMixin(metaclass=ABCMeta): - def setUp(self): super().setUp() - self.foo_bucket = FooBucket('test_foo_bucket', [ - Foo('my_foo1', list(range(10)), 'value1', 10), - Foo('my_foo2', list(range(10, 20)), 'value2', 20)]) + self.foo_bucket = FooBucket( + "test_foo_bucket", + [ + Foo("my_foo1", list(range(10)), "value1", 10), + Foo("my_foo2", list(range(10, 20)), "value2", 20), + ], + ) self.foo_builders = { - 
'my_foo1': GroupBuilder('my_foo1', - datasets={'my_data': DatasetBuilder( - 'my_data', - list(range(10)), - attributes={'attr2': 10})}, - attributes={'attr1': 'value1', 'namespace': CORE_NAMESPACE, 'data_type': 'Foo', - 'object_id': self.foo_bucket.foos['my_foo1'].object_id}), - 'my_foo2': GroupBuilder('my_foo2', datasets={'my_data': - DatasetBuilder( - 'my_data', - list(range(10, 20)), - attributes={'attr2': 20})}, - attributes={'attr1': 'value2', 'namespace': CORE_NAMESPACE, 'data_type': 'Foo', - 'object_id': self.foo_bucket.foos['my_foo2'].object_id}) + "my_foo1": GroupBuilder( + "my_foo1", + datasets={"my_data": DatasetBuilder("my_data", list(range(10)), attributes={"attr2": 10})}, + attributes={ + "attr1": "value1", + "namespace": CORE_NAMESPACE, + "data_type": "Foo", + "object_id": self.foo_bucket.foos["my_foo1"].object_id, + }, + ), + "my_foo2": GroupBuilder( + "my_foo2", + datasets={"my_data": DatasetBuilder("my_data", list(range(10, 20)), attributes={"attr2": 20})}, + attributes={ + "attr1": "value2", + "namespace": CORE_NAMESPACE, + "data_type": "Foo", + "object_id": self.foo_bucket.foos["my_foo2"].object_id, + }, + ), } self.setUpBucketBuilder() self.setUpBucketSpec() - self.spec_catalog.register_spec(self.bucket_spec, 'test.yaml') - self.type_map.register_container_type(CORE_NAMESPACE, 'FooBucket', FooBucket) + self.spec_catalog.register_spec(self.bucket_spec, "test.yaml") + self.type_map.register_container_type(CORE_NAMESPACE, "FooBucket", FooBucket) self.type_map.register_map(FooBucket, self.setUpBucketMapper()) self.manager = BuildManager(self.type_map) @abstractmethod def setUpBucketBuilder(self): - raise NotImplementedError('Cannot run test unless setUpBucketBuilder is implemented') + raise NotImplementedError("Cannot run test unless setUpBucketBuilder is implemented") @abstractmethod def setUpBucketSpec(self): - raise NotImplementedError('Cannot run test unless setUpBucketSpec is implemented') + raise NotImplementedError("Cannot run test unless setUpBucketSpec is implemented") @abstractmethod def setUpBucketMapper(self): - raise NotImplementedError('Cannot run test unless setUpBucketMapper is implemented') + raise NotImplementedError("Cannot run test unless setUpBucketMapper is implemented") def test_build(self): - ''' Test default mapping for an Container that has an Container as an attribute value ''' + """Test default mapping for an Container that has an Container as an attribute value""" builder = self.manager.build(self.foo_bucket) self.assertDictEqual(builder, self.bucket_builder) @@ -191,105 +215,147 @@ def test_construct(self): class TestNestedContainersNoSubgroups(NestedBaseMixin, TestBase): - ''' - Test BuildManager.build and BuildManager.construct when the - Container contains other Containers, but does not keep them in - additional subgroups - ''' + """ + Test BuildManager.build and BuildManager.construct when the + Container contains other Containers, but does not keep them in + additional subgroups + """ def setUpBucketBuilder(self): self.bucket_builder = GroupBuilder( - 'test_foo_bucket', + "test_foo_bucket", groups=self.foo_builders, - attributes={'namespace': CORE_NAMESPACE, 'data_type': 'FooBucket', 'object_id': self.foo_bucket.object_id}) + attributes={ + "namespace": CORE_NAMESPACE, + "data_type": "FooBucket", + "object_id": self.foo_bucket.object_id, + }, + ) def setUpBucketSpec(self): - self.bucket_spec = GroupSpec('A test group specification for a data type containing data type', - name="test_foo_bucket", - data_type_def='FooBucket', - 
groups=[GroupSpec( - 'the Foos in this bucket', - data_type_inc='Foo', - quantity=ZERO_OR_MANY)]) + self.bucket_spec = GroupSpec( + "A test group specification for a data type containing data type", + name="test_foo_bucket", + data_type_def="FooBucket", + groups=[ + GroupSpec( + "the Foos in this bucket", + data_type_inc="Foo", + quantity=ZERO_OR_MANY, + ) + ], + ) def setUpBucketMapper(self): return ObjectMapper class TestNestedContainersSubgroup(NestedBaseMixin, TestBase): - ''' - Test BuildManager.build and BuildManager.construct when the - Container contains other Containers that are stored in a subgroup - ''' + """ + Test BuildManager.build and BuildManager.construct when the + Container contains other Containers that are stored in a subgroup + """ def setUpBucketBuilder(self): - tmp_builder = GroupBuilder('foo_holder', groups=self.foo_builders) + tmp_builder = GroupBuilder("foo_holder", groups=self.foo_builders) self.bucket_builder = GroupBuilder( - 'test_foo_bucket', - groups={'foos': tmp_builder}, - attributes={'namespace': CORE_NAMESPACE, 'data_type': 'FooBucket', 'object_id': self.foo_bucket.object_id}) + "test_foo_bucket", + groups={"foos": tmp_builder}, + attributes={ + "namespace": CORE_NAMESPACE, + "data_type": "FooBucket", + "object_id": self.foo_bucket.object_id, + }, + ) def setUpBucketSpec(self): tmp_spec = GroupSpec( - 'A subgroup for Foos', - name='foo_holder', - groups=[GroupSpec('the Foos in this bucket', - data_type_inc='Foo', - quantity=ZERO_OR_MANY)]) - self.bucket_spec = GroupSpec('A test group specification for a data type containing data type', - name="test_foo_bucket", - data_type_def='FooBucket', - groups=[tmp_spec]) + "A subgroup for Foos", + name="foo_holder", + groups=[ + GroupSpec( + "the Foos in this bucket", + data_type_inc="Foo", + quantity=ZERO_OR_MANY, + ) + ], + ) + self.bucket_spec = GroupSpec( + "A test group specification for a data type containing data type", + name="test_foo_bucket", + data_type_def="FooBucket", + groups=[tmp_spec], + ) def setUpBucketMapper(self): class BucketMapper(ObjectMapper): def __init__(self, spec): super().__init__(spec) - self.unmap(spec.get_group('foo_holder')) - self.map_spec('foos', spec.get_group('foo_holder').get_data_type('Foo')) + self.unmap(spec.get_group("foo_holder")) + self.map_spec("foos", spec.get_group("foo_holder").get_data_type("Foo")) return BucketMapper class TestNestedContainersSubgroupSubgroup(NestedBaseMixin, TestBase): - ''' - Test BuildManager.build and BuildManager.construct when the - Container contains other Containers that are stored in a subgroup - in a subgroup - ''' + """ + Test BuildManager.build and BuildManager.construct when the + Container contains other Containers that are stored in a subgroup + in a subgroup + """ def setUpBucketBuilder(self): - tmp_builder = GroupBuilder('foo_holder', groups=self.foo_builders) - tmp_builder = GroupBuilder('foo_holder_holder', groups={'foo_holder': tmp_builder}) + tmp_builder = GroupBuilder("foo_holder", groups=self.foo_builders) + tmp_builder = GroupBuilder("foo_holder_holder", groups={"foo_holder": tmp_builder}) self.bucket_builder = GroupBuilder( - 'test_foo_bucket', - groups={'foo_holder': tmp_builder}, - attributes={'namespace': CORE_NAMESPACE, 'data_type': 'FooBucket', 'object_id': self.foo_bucket.object_id}) + "test_foo_bucket", + groups={"foo_holder": tmp_builder}, + attributes={ + "namespace": CORE_NAMESPACE, + "data_type": "FooBucket", + "object_id": self.foo_bucket.object_id, + }, + ) def setUpBucketSpec(self): - tmp_spec = GroupSpec('A 
subgroup for Foos', - name='foo_holder', - groups=[GroupSpec('the Foos in this bucket', - data_type_inc='Foo', - quantity=ZERO_OR_MANY)]) - tmp_spec = GroupSpec('A subgroup to hold the subgroup', name='foo_holder_holder', groups=[tmp_spec]) - self.bucket_spec = GroupSpec('A test group specification for a data type containing data type', - name="test_foo_bucket", - data_type_def='FooBucket', - groups=[tmp_spec]) + tmp_spec = GroupSpec( + "A subgroup for Foos", + name="foo_holder", + groups=[ + GroupSpec( + "the Foos in this bucket", + data_type_inc="Foo", + quantity=ZERO_OR_MANY, + ) + ], + ) + tmp_spec = GroupSpec( + "A subgroup to hold the subgroup", + name="foo_holder_holder", + groups=[tmp_spec], + ) + self.bucket_spec = GroupSpec( + "A test group specification for a data type containing data type", + name="test_foo_bucket", + data_type_def="FooBucket", + groups=[tmp_spec], + ) def setUpBucketMapper(self): class BucketMapper(ObjectMapper): def __init__(self, spec): super().__init__(spec) - self.unmap(spec.get_group('foo_holder_holder')) - self.unmap(spec.get_group('foo_holder_holder').get_group('foo_holder')) - self.map_spec('foos', spec.get_group('foo_holder_holder').get_group('foo_holder').get_data_type('Foo')) + self.unmap(spec.get_group("foo_holder_holder")) + self.unmap(spec.get_group("foo_holder_holder").get_group("foo_holder")) + self.map_spec( + "foos", + spec.get_group("foo_holder_holder").get_group("foo_holder").get_data_type("Foo"), + ) return BucketMapper def test_build(self): - ''' Test default mapping for an Container that has an Container as an attribute value ''' + """Test default mapping for an Container that has an Container as an attribute value""" builder = self.manager.build(self.foo_bucket) self.assertDictEqual(builder, self.bucket_builder) @@ -299,43 +365,49 @@ def test_construct(self): class TestNoAttribute(TestBase): - def test_build(self): """Test that an error is raised when a spec is mapped to a non-existent container attribute.""" + class Unmapper(ObjectMapper): def __init__(self, spec): super().__init__(spec) - self.map_spec("unknown", self.spec.get_dataset('my_data')) + self.map_spec("unknown", self.spec.get_dataset("my_data")) self.type_map.register_map(Foo, Unmapper) # override - container_inst = Foo('my_foo', list(range(10)), 'value1', 10) - msg = ("Foo 'my_foo' does not have attribute 'unknown' for mapping to spec: %s" - % self.foo_spec.get_dataset('my_data')) + container_inst = Foo("my_foo", list(range(10)), "value1", 10) + msg = "Foo 'my_foo' does not have attribute 'unknown' for mapping to spec: %s" % self.foo_spec.get_dataset( + "my_data" + ) with self.assertRaisesWith(ContainerConfigurationError, msg): self.manager.build(container_inst) class TestTypeMap(TestBase): - def test_get_ns_dt_missing(self): - bldr = GroupBuilder('my_foo', attributes={'attr1': 'value1'}) + bldr = GroupBuilder("my_foo", attributes={"attr1": "value1"}) dt = self.type_map.get_builder_dt(bldr) ns = self.type_map.get_builder_ns(bldr) self.assertIsNone(dt) self.assertIsNone(ns) def test_get_ns_dt(self): - bldr = GroupBuilder('my_foo', attributes={'attr1': 'value1', 'namespace': 'CORE', 'data_type': 'Foo', - 'object_id': -1}) + bldr = GroupBuilder( + "my_foo", + attributes={ + "attr1": "value1", + "namespace": "CORE", + "data_type": "Foo", + "object_id": -1, + }, + ) dt = self.type_map.get_builder_dt(bldr) ns = self.type_map.get_builder_ns(bldr) - self.assertEqual(dt, 'Foo') - self.assertEqual(ns, 'CORE') + self.assertEqual(dt, "Foo") + self.assertEqual(ns, "CORE") class 
TestRetrieveContainerClass(TestBase): - def test_get_dt_container_cls(self): ret = self.type_map.get_dt_container_cls(data_type="Foo") self.assertIs(ret, Foo) diff --git a/tests/unit/build_tests/test_io_map.py b/tests/unit/build_tests/test_io_map.py index 63f397682..6622b73d9 100644 --- a/tests/unit/build_tests/test_io_map.py +++ b/tests/unit/build_tests/test_io_map.py @@ -1,28 +1,47 @@ -from hdmf.utils import docval, getargs +import unittest +from abc import ABCMeta, abstractmethod + from hdmf import Container, Data from hdmf.backends.hdf5 import H5DataIO -from hdmf.build import (GroupBuilder, DatasetBuilder, ObjectMapper, BuildManager, TypeMap, LinkBuilder, - ReferenceBuilder, MissingRequiredBuildWarning, OrphanContainerBuildError, - ContainerConfigurationError) -from hdmf.spec import (GroupSpec, AttributeSpec, DatasetSpec, SpecCatalog, SpecNamespace, NamespaceCatalog, RefSpec, - LinkSpec) +from hdmf.build import ( + BuildManager, + ContainerConfigurationError, + DatasetBuilder, + GroupBuilder, + LinkBuilder, + MissingRequiredBuildWarning, + ObjectMapper, + OrphanContainerBuildError, + ReferenceBuilder, + TypeMap, +) +from hdmf.spec import ( + AttributeSpec, + DatasetSpec, + GroupSpec, + LinkSpec, + NamespaceCatalog, + RefSpec, + SpecCatalog, + SpecNamespace, +) from hdmf.testing import TestCase -from abc import ABCMeta, abstractmethod -import unittest +from hdmf.utils import docval, getargs -from tests.unit.helpers.utils import CORE_NAMESPACE, create_test_type_map +from ..helpers.utils import CORE_NAMESPACE, create_test_type_map class Bar(Container): - - @docval({'name': 'name', 'type': str, 'doc': 'the name of this Bar'}, - {'name': 'data', 'type': ('data', 'array_data'), 'doc': 'some data'}, - {'name': 'attr1', 'type': str, 'doc': 'an attribute'}, - {'name': 'attr2', 'type': int, 'doc': 'another attribute'}, - {'name': 'attr3', 'type': float, 'doc': 'a third attribute', 'default': 3.14}, - {'name': 'foo', 'type': 'Foo', 'doc': 'a group', 'default': None}) + @docval( + {"name": "name", "type": str, "doc": "the name of this Bar"}, + {"name": "data", "type": ("data", "array_data"), "doc": "some data"}, + {"name": "attr1", "type": str, "doc": "an attribute"}, + {"name": "attr2", "type": int, "doc": "another attribute"}, + {"name": "attr3", "type": float, "doc": "a third attribute", "default": 3.14}, + {"name": "foo", "type": "Foo", "doc": "a group", "default": None}, + ) def __init__(self, **kwargs): - name, data, attr1, attr2, attr3, foo = getargs('name', 'data', 'attr1', 'attr2', 'attr3', 'foo', kwargs) + name, data, attr1, attr2, attr3, foo = getargs("name", "data", "attr1", "attr2", "attr3", "foo", kwargs) super().__init__(name=name) self.__data = data self.__attr1 = attr1 @@ -33,16 +52,16 @@ def __init__(self, **kwargs): self.__foo.parent = self def __eq__(self, other): - attrs = ('name', 'data', 'attr1', 'attr2', 'attr3', 'foo') + attrs = ("name", "data", "attr1", "attr2", "attr3", "foo") return all(getattr(self, a) == getattr(other, a) for a in attrs) def __str__(self): - attrs = ('name', 'data', 'attr1', 'attr2', 'attr3', 'foo') - return ','.join('%s=%s' % (a, getattr(self, a)) for a in attrs) + attrs = ("name", "data", "attr1", "attr2", "attr3", "foo") + return ",".join("%s=%s" % (a, getattr(self, a)) for a in attrs) @property def data_type(self): - return 'Bar' + return "Bar" @property def data(self): @@ -74,92 +93,89 @@ class SubBar(Bar): class Foo(Container): - @property def data_type(self): - return 'Foo' + return "Foo" class FooData(Data): - @property def data_type(self): - 
return 'FooData' + return "FooData" class TestGetSubSpec(TestCase): - def setUp(self): - self.bar_spec = GroupSpec(doc='A test group specification with a data type', data_type_def='Bar') + self.bar_spec = GroupSpec( + doc="A test group specification with a data type", + data_type_def="Bar", + ) self.sub_bar_spec = GroupSpec( - doc='A test group specification with a data type', - data_type_def='SubBar', - data_type_inc='Bar' + doc="A test group specification with a data type", + data_type_def="SubBar", + data_type_inc="Bar", ) - self.type_map = create_test_type_map([self.bar_spec, self.sub_bar_spec], {'Bar': Bar, 'SubBar': SubBar}) + self.type_map = create_test_type_map([self.bar_spec, self.sub_bar_spec], {"Bar": Bar, "SubBar": SubBar}) def test_bad_name(self): """Test get_subspec on a builder that doesn't map to the spec.""" - parent_spec = GroupSpec(doc='Empty group', name='bar_bucket') + parent_spec = GroupSpec(doc="Empty group", name="bar_bucket") sub_builder = GroupBuilder( - name='my_bar', + name="my_bar", attributes={ - 'data_type': 'Bar', - 'namespace': CORE_NAMESPACE, - 'object_id': -1 - } + "data_type": "Bar", + "namespace": CORE_NAMESPACE, + "object_id": -1, + }, ) - GroupBuilder(name='bar_bucket', groups={'my_bar': sub_builder}) # add sub_builder as a child to bar_bucket + GroupBuilder(name="bar_bucket", groups={"my_bar": sub_builder}) # add sub_builder as a child to bar_bucket result = self.type_map.get_subspec(parent_spec, sub_builder) self.assertIsNone(result) def test_bad_name_no_data_type(self): """Test get_subspec on a builder without a data type that doesn't map to the spec.""" - parent_spec = GroupSpec(doc='Empty group', name='bar_bucket') - sub_builder = GroupBuilder(name='my_bar') - GroupBuilder(name='bar_bucket', groups={'my_bar': sub_builder}) # add sub_builder as a child to bar_bucket + parent_spec = GroupSpec(doc="Empty group", name="bar_bucket") + sub_builder = GroupBuilder(name="my_bar") + GroupBuilder(name="bar_bucket", groups={"my_bar": sub_builder}) # add sub_builder as a child to bar_bucket result = self.type_map.get_subspec(parent_spec, sub_builder) self.assertIsNone(result) def test_unnamed_group_data_type_def(self): """Test get_subspec on a builder that maps to an unnamed subgroup of the given spec using data_type_def.""" parent_spec = GroupSpec( - doc='Something to hold a Bar', - name='bar_bucket', - groups=[self.bar_spec] # using data_type_def, no name + doc="Something to hold a Bar", + name="bar_bucket", + groups=[self.bar_spec], # using data_type_def, no name ) sub_builder = GroupBuilder( - name='my_bar', + name="my_bar", attributes={ - 'data_type': 'Bar', - 'namespace': CORE_NAMESPACE, - 'object_id': -1 - } + "data_type": "Bar", + "namespace": CORE_NAMESPACE, + "object_id": -1, + }, ) - GroupBuilder(name='bar_bucket', groups={'my_bar': sub_builder}) # add sub_builder as a child to bar_bucket + GroupBuilder(name="bar_bucket", groups={"my_bar": sub_builder}) # add sub_builder as a child to bar_bucket result = self.type_map.get_subspec(parent_spec, sub_builder) self.assertIs(result, self.bar_spec) def test_unnamed_group_data_type_inc(self): """Test get_subspec on a builder that maps to an unnamed subgroup of the given spec using data_type_inc.""" - inc_spec = GroupSpec( - doc='This Bar', - data_type_inc='Bar' - ) + inc_spec = GroupSpec(doc="This Bar", data_type_inc="Bar") parent_spec = GroupSpec( - doc='Something to hold a Bar', - name='bar_bucket', - groups=[inc_spec] # using data_type_inc + doc="Something to hold a Bar", + name="bar_bucket", + 
groups=[inc_spec], # using data_type_inc ) sub_builder = GroupBuilder( - name='my_bar', + name="my_bar", attributes={ - 'data_type': 'Bar', - 'namespace': CORE_NAMESPACE, - 'object_id': -1 - } + "data_type": "Bar", + "namespace": CORE_NAMESPACE, + "object_id": -1, + }, ) - GroupBuilder(name='bar_bucket', groups={'my_bar': sub_builder}) # add sub_builder as a child to bar_bucket + GroupBuilder(name="bar_bucket", groups={"my_bar": sub_builder}) # add sub_builder as a child to bar_bucket result = self.type_map.get_subspec(parent_spec, sub_builder) self.assertIs(result, inc_spec) @@ -167,17 +183,17 @@ def test_named_group(self): """Test get_subspec on a builder that maps to a named subgroup of the given spec.""" # NOTE this works despite the fact that child_spec has no data type but the builder has a data type because # get_subspec acts on the name and not necessarily the data type - child_spec = GroupSpec(doc='A test group specification', name='my_subgroup') - parent_spec = GroupSpec(doc='Something to hold a Bar', name='my_group', groups=[child_spec]) + child_spec = GroupSpec(doc="A test group specification", name="my_subgroup") + parent_spec = GroupSpec(doc="Something to hold a Bar", name="my_group", groups=[child_spec]) sub_builder = GroupBuilder( - name='my_subgroup', + name="my_subgroup", attributes={ - 'data_type': 'Bar', - 'namespace': CORE_NAMESPACE, - 'object_id': -1 - } + "data_type": "Bar", + "namespace": CORE_NAMESPACE, + "object_id": -1, + }, ) - GroupBuilder(name='my_group', groups={'my_bar': sub_builder}) # add sub_builder as a child to my_group + GroupBuilder(name="my_group", groups={"my_bar": sub_builder}) # add sub_builder as a child to my_group result = self.type_map.get_subspec(parent_spec, sub_builder) self.assertIs(result, child_spec) @@ -185,111 +201,106 @@ def test_named_dataset(self): """Test get_subspec on a builder that maps to a named dataset of the given spec.""" # NOTE this works despite the fact that child_spec has no data type but the builder has a data type because # get_subspec acts on the name and not necessarily the data type - child_spec = DatasetSpec(doc='A test dataset specification', name='my_dataset') - parent_spec = GroupSpec(doc='Something to hold a Bar', name='my_group', datasets=[child_spec]) + child_spec = DatasetSpec(doc="A test dataset specification", name="my_dataset") + parent_spec = GroupSpec( + doc="Something to hold a Bar", + name="my_group", + datasets=[child_spec], + ) sub_builder = DatasetBuilder( - name='my_dataset', + name="my_dataset", data=[], attributes={ - 'data_type': 'FooData', - 'namespace': CORE_NAMESPACE, - 'object_id': -1 - } + "data_type": "FooData", + "namespace": CORE_NAMESPACE, + "object_id": -1, + }, ) - GroupBuilder(name='my_group', datasets={'my_dataset': sub_builder}) # add sub_builder as a child to my_group + GroupBuilder(name="my_group", datasets={"my_dataset": sub_builder}) # add sub_builder as a child to my_group result = self.type_map.get_subspec(parent_spec, sub_builder) self.assertIs(result, child_spec) def test_unnamed_link_data_type_inc(self): """Test get_subspec on a builder that maps to an unnamed link.""" - link_spec = LinkSpec(doc='This Bar', target_type='Bar') - parent_spec = GroupSpec( - doc='Something to hold a Bar', - name='bar_bucket', - links=[link_spec] - ) + link_spec = LinkSpec(doc="This Bar", target_type="Bar") + parent_spec = GroupSpec(doc="Something to hold a Bar", name="bar_bucket", links=[link_spec]) bar_builder = GroupBuilder( - name='my_bar', + name="my_bar", attributes={ - 'data_type': 
'Bar', - 'namespace': CORE_NAMESPACE, - 'object_id': -1 - } + "data_type": "Bar", + "namespace": CORE_NAMESPACE, + "object_id": -1, + }, ) - sub_builder = LinkBuilder(builder=bar_builder, name='my_link') - GroupBuilder(name='bar_bucket', links={'my_bar': sub_builder}) + sub_builder = LinkBuilder(builder=bar_builder, name="my_link") + GroupBuilder(name="bar_bucket", links={"my_bar": sub_builder}) result = self.type_map.get_subspec(parent_spec, sub_builder) self.assertIs(result, link_spec) def test_named_link_data_type_inc(self): """Test get_subspec on a builder that maps to an named link.""" - link_spec = LinkSpec(doc='This Bar', target_type='Bar', name='bar_link') - parent_spec = GroupSpec( - doc='Something to hold a Bar', - name='bar_bucket', - links=[link_spec] - ) + link_spec = LinkSpec(doc="This Bar", target_type="Bar", name="bar_link") + parent_spec = GroupSpec(doc="Something to hold a Bar", name="bar_bucket", links=[link_spec]) bar_builder = GroupBuilder( - name='my_bar', + name="my_bar", attributes={ - 'data_type': 'Bar', - 'namespace': CORE_NAMESPACE, - 'object_id': -1 - } + "data_type": "Bar", + "namespace": CORE_NAMESPACE, + "object_id": -1, + }, ) - sub_builder = LinkBuilder(builder=bar_builder, name='bar_link') - GroupBuilder(name='bar_bucket', links={'my_bar': sub_builder}) + sub_builder = LinkBuilder(builder=bar_builder, name="bar_link") + GroupBuilder(name="bar_bucket", links={"my_bar": sub_builder}) result = self.type_map.get_subspec(parent_spec, sub_builder) self.assertIs(result, link_spec) def test_named_link_hierarchy_data_type_inc(self): """Test get_subspec on a builder that maps to an named link.""" - link_spec = LinkSpec(doc='This Bar', target_type='Bar', name='bar_link') - parent_spec = GroupSpec( - doc='Something to hold a Bar', - name='bar_bucket', - links=[link_spec] - ) + link_spec = LinkSpec(doc="This Bar", target_type="Bar", name="bar_link") + parent_spec = GroupSpec(doc="Something to hold a Bar", name="bar_bucket", links=[link_spec]) bar_builder = GroupBuilder( - name='my_bar', + name="my_bar", attributes={ - 'data_type': 'SubBar', - 'namespace': CORE_NAMESPACE, - 'object_id': -1 - } + "data_type": "SubBar", + "namespace": CORE_NAMESPACE, + "object_id": -1, + }, ) - sub_builder = LinkBuilder(builder=bar_builder, name='bar_link') - GroupBuilder(name='bar_bucket', links={'my_bar': sub_builder}) + sub_builder = LinkBuilder(builder=bar_builder, name="bar_link") + GroupBuilder(name="bar_bucket", links={"my_bar": sub_builder}) result = self.type_map.get_subspec(parent_spec, sub_builder) self.assertIs(result, link_spec) class TestTypeMap(TestCase): - def setUp(self): - self.bar_spec = GroupSpec('A test group specification with a data type', data_type_def='Bar') - self.foo_spec = GroupSpec('A test group specification with data type Foo', data_type_def='Foo') + self.bar_spec = GroupSpec("A test group specification with a data type", data_type_def="Bar") + self.foo_spec = GroupSpec("A test group specification with data type Foo", data_type_def="Foo") self.spec_catalog = SpecCatalog() - self.spec_catalog.register_spec(self.bar_spec, 'test.yaml') - self.spec_catalog.register_spec(self.foo_spec, 'test.yaml') - self.namespace = SpecNamespace('a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], - version='0.1.0', - catalog=self.spec_catalog) + self.spec_catalog.register_spec(self.bar_spec, "test.yaml") + self.spec_catalog.register_spec(self.foo_spec, "test.yaml") + self.namespace = SpecNamespace( + "a test namespace", + CORE_NAMESPACE, + [{"source": "test.yaml"}], 
+ version="0.1.0", + catalog=self.spec_catalog, + ) self.namespace_catalog = NamespaceCatalog() self.namespace_catalog.add_namespace(CORE_NAMESPACE, self.namespace) self.type_map = TypeMap(self.namespace_catalog) - self.type_map.register_container_type(CORE_NAMESPACE, 'Bar', Bar) - self.type_map.register_container_type(CORE_NAMESPACE, 'Foo', Foo) + self.type_map.register_container_type(CORE_NAMESPACE, "Bar", Bar) + self.type_map.register_container_type(CORE_NAMESPACE, "Foo", Foo) def test_get_map_unique_mappers(self): - bar_inst = Bar('my_bar', list(range(10)), 'value1', 10) - foo_inst = Foo(name='my_foo') + bar_inst = Bar("my_bar", list(range(10)), "value1", 10) + foo_inst = Foo(name="my_foo") bar_mapper = self.type_map.get_map(bar_inst) foo_mapper = self.type_map.get_map(foo_inst) self.assertIsNot(bar_mapper, foo_mapper) def test_get_map(self): - container_inst = Bar('my_bar', list(range(10)), 'value1', 10) + container_inst = Bar("my_bar", list(range(10)), "value1", 10) mapper = self.type_map.get_map(container_inst) self.assertIsInstance(mapper, ObjectMapper) self.assertIs(mapper.spec, self.bar_spec) @@ -299,9 +310,10 @@ def test_get_map(self): def test_get_map_register(self): class MyMap(ObjectMapper): pass + self.type_map.register_map(Bar, MyMap) - container_inst = Bar('my_bar', list(range(10)), 'value1', 10) + container_inst = Bar("my_bar", list(range(10)), "value1", 10) mapper = self.type_map.get_map(container_inst) self.assertIs(mapper.spec, self.bar_spec) self.assertIsInstance(mapper, MyMap) @@ -310,83 +322,112 @@ class MyMap(ObjectMapper): class BarMapper(ObjectMapper): def __init__(self, spec): super().__init__(spec) - data_spec = spec.get_dataset('data') - self.map_spec('attr2', data_spec.get_attribute('attr2')) + data_spec = spec.get_dataset("data") + self.map_spec("attr2", data_spec.get_attribute("attr2")) class TestMapStrings(TestCase): - def customSetUp(self, bar_spec): spec_catalog = SpecCatalog() - spec_catalog.register_spec(bar_spec, 'test.yaml') - namespace = SpecNamespace('a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], version='0.1.0', - catalog=spec_catalog) + spec_catalog.register_spec(bar_spec, "test.yaml") + namespace = SpecNamespace( + "a test namespace", + CORE_NAMESPACE, + [{"source": "test.yaml"}], + version="0.1.0", + catalog=spec_catalog, + ) namespace_catalog = NamespaceCatalog() namespace_catalog.add_namespace(CORE_NAMESPACE, namespace) type_map = TypeMap(namespace_catalog) - type_map.register_container_type(CORE_NAMESPACE, 'Bar', Bar) + type_map.register_container_type(CORE_NAMESPACE, "Bar", Bar) return type_map def test_build_1d(self): - bar_spec = GroupSpec('A test group specification with a data type', - data_type_def='Bar', - datasets=[DatasetSpec('an example dataset', 'text', name='data', shape=(None,), - attributes=[AttributeSpec( - 'attr2', 'an example integer attribute', 'int')])], - attributes=[AttributeSpec('attr1', 'an example string attribute', 'text')]) + bar_spec = GroupSpec( + "A test group specification with a data type", + data_type_def="Bar", + datasets=[ + DatasetSpec( + "an example dataset", + "text", + name="data", + shape=(None,), + attributes=[AttributeSpec("attr2", "an example integer attribute", "int")], + ) + ], + attributes=[AttributeSpec("attr1", "an example string attribute", "text")], + ) type_map = self.customSetUp(bar_spec) type_map.register_map(Bar, BarMapper) - bar_inst = Bar('my_bar', ['a', 'b', 'c', 'd'], 'value1', 10) + bar_inst = Bar("my_bar", ["a", "b", "c", "d"], "value1", 10) builder = 
type_map.build(bar_inst) - self.assertEqual(builder.get('data').data, ['a', 'b', 'c', 'd']) + self.assertEqual(builder.get("data").data, ["a", "b", "c", "d"]) def test_build_scalar(self): - bar_spec = GroupSpec('A test group specification with a data type', - data_type_def='Bar', - datasets=[DatasetSpec('an example dataset', 'text', name='data', - attributes=[AttributeSpec( - 'attr2', 'an example integer attribute', 'int')])], - attributes=[AttributeSpec('attr1', 'an example string attribute', 'text')]) + bar_spec = GroupSpec( + "A test group specification with a data type", + data_type_def="Bar", + datasets=[ + DatasetSpec( + "an example dataset", + "text", + name="data", + attributes=[AttributeSpec("attr2", "an example integer attribute", "int")], + ) + ], + attributes=[AttributeSpec("attr1", "an example string attribute", "text")], + ) type_map = self.customSetUp(bar_spec) type_map.register_map(Bar, BarMapper) - bar_inst = Bar('my_bar', ['a', 'b', 'c', 'd'], 'value1', 10) + bar_inst = Bar("my_bar", ["a", "b", "c", "d"], "value1", 10) builder = type_map.build(bar_inst) - self.assertEqual(builder.get('data').data, "['a', 'b', 'c', 'd']") + self.assertEqual(builder.get("data").data, "['a', 'b', 'c', 'd']") def test_build_dataio(self): - bar_spec = GroupSpec('A test group specification with a data type', - data_type_def='Bar', - datasets=[DatasetSpec('an example dataset', 'text', name='data', shape=(None,), - attributes=[AttributeSpec( - 'attr2', 'an example integer attribute', 'int')])], - attributes=[AttributeSpec('attr1', 'an example string attribute', 'text')]) + bar_spec = GroupSpec( + "A test group specification with a data type", + data_type_def="Bar", + datasets=[ + DatasetSpec( + "an example dataset", + "text", + name="data", + shape=(None,), + attributes=[AttributeSpec("attr2", "an example integer attribute", "int")], + ) + ], + attributes=[AttributeSpec("attr1", "an example string attribute", "text")], + ) type_map = self.customSetUp(bar_spec) type_map.register_map(Bar, BarMapper) - bar_inst = Bar('my_bar', H5DataIO(['a', 'b', 'c', 'd'], chunks=True), 'value1', 10) + bar_inst = Bar("my_bar", H5DataIO(["a", "b", "c", "d"], chunks=True), "value1", 10) builder = type_map.build(bar_inst) - self.assertIsInstance(builder.get('data').data, H5DataIO) + self.assertIsInstance(builder.get("data").data, H5DataIO) class ObjectMapperMixin(metaclass=ABCMeta): - def setUp(self): self.setUpBarSpec() self.spec_catalog = SpecCatalog() - self.spec_catalog.register_spec(self.bar_spec, 'test.yaml') - self.namespace = SpecNamespace('a test namespace', CORE_NAMESPACE, - [{'source': 'test.yaml'}], - version='0.1.0', - catalog=self.spec_catalog) + self.spec_catalog.register_spec(self.bar_spec, "test.yaml") + self.namespace = SpecNamespace( + "a test namespace", + CORE_NAMESPACE, + [{"source": "test.yaml"}], + version="0.1.0", + catalog=self.spec_catalog, + ) self.namespace_catalog = NamespaceCatalog() self.namespace_catalog.add_namespace(CORE_NAMESPACE, self.namespace) self.type_map = TypeMap(self.namespace_catalog) - self.type_map.register_container_type(CORE_NAMESPACE, 'Bar', Bar) + self.type_map.register_container_type(CORE_NAMESPACE, "Bar", Bar) self.manager = BuildManager(self.type_map) self.mapper = ObjectMapper(self.bar_spec) @abstractmethod def setUpBarSpec(self): - raise NotImplementedError('Cannot run test unless setUpBarSpec is implemented') + raise NotImplementedError("Cannot run test unless setUpBarSpec is implemented") def test_default_mapping(self): attr_map = 
self.mapper.get_attr_names(self.bar_spec) @@ -398,47 +439,47 @@ def test_default_mapping(self): class TestObjectMapperNested(ObjectMapperMixin, TestCase): - def setUpBarSpec(self): - self.bar_spec = GroupSpec('A test group specification with a data type', - data_type_def='Bar', - datasets=[DatasetSpec('an example dataset', 'int', name='data', - attributes=[AttributeSpec( - 'attr2', 'an example integer attribute', 'int')])], - attributes=[AttributeSpec('attr1', 'an example string attribute', 'text')]) + self.bar_spec = GroupSpec( + "A test group specification with a data type", + data_type_def="Bar", + datasets=[ + DatasetSpec( + "an example dataset", + "int", + name="data", + attributes=[AttributeSpec("attr2", "an example integer attribute", "int")], + ) + ], + attributes=[AttributeSpec("attr1", "an example string attribute", "text")], + ) def test_build(self): - ''' Test default mapping functionality when object attributes map to an attribute deeper - than top-level Builder ''' - container_inst = Bar('my_bar', list(range(10)), 'value1', 10) + """Test default mapping functionality when object attributes map to an attribute deeper + than top-level Builder""" + container_inst = Bar("my_bar", list(range(10)), "value1", 10) expected = GroupBuilder( - name='my_bar', - datasets={'data': DatasetBuilder( - name='data', - data=list(range(10)), - attributes={'attr2': 10} - )}, - attributes={'attr1': 'value1'} + name="my_bar", + datasets={"data": DatasetBuilder(name="data", data=list(range(10)), attributes={"attr2": 10})}, + attributes={"attr1": "value1"}, ) self._remap_nested_attr() builder = self.mapper.build(container_inst, self.manager) self.assertBuilderEqual(builder, expected) def test_construct(self): - ''' Test default mapping functionality when object attributes map to an attribute - deeper than top-level Builder ''' - expected = Bar('my_bar', list(range(10)), 'value1', 10) + """Test default mapping functionality when object attributes map to an attribute + deeper than top-level Builder""" + expected = Bar("my_bar", list(range(10)), "value1", 10) builder = GroupBuilder( - name='my_bar', - datasets={'data': DatasetBuilder( - name='data', - data=list(range(10)), - attributes={'attr2': 10} - )}, - attributes={'attr1': 'value1', - 'data_type': 'Bar', - 'namespace': CORE_NAMESPACE, - 'object_id': expected.object_id} + name="my_bar", + datasets={"data": DatasetBuilder(name="data", data=list(range(10)), attributes={"attr2": 10})}, + attributes={ + "attr1": "value1", + "data_type": "Bar", + "namespace": CORE_NAMESPACE, + "object_id": expected.object_id, + }, ) self._remap_nested_attr() container = self.mapper.construct(builder, self.manager) @@ -447,226 +488,290 @@ def test_construct(self): def test_default_mapping_keys(self): attr_map = self.mapper.get_attr_names(self.bar_spec) keys = set(attr_map.keys()) - expected = {'attr1', 'data', 'data__attr2'} + expected = {"attr1", "data", "data__attr2"} self.assertSetEqual(keys, expected) def test_remap_keys(self): self._remap_nested_attr() - self.assertEqual(self.mapper.get_attr_spec('attr2'), - self.mapper.spec.get_dataset('data').get_attribute('attr2')) - self.assertEqual(self.mapper.get_attr_spec('attr1'), self.mapper.spec.get_attribute('attr1')) - self.assertEqual(self.mapper.get_attr_spec('data'), self.mapper.spec.get_dataset('data')) + self.assertEqual( + self.mapper.get_attr_spec("attr2"), + self.mapper.spec.get_dataset("data").get_attribute("attr2"), + ) + self.assertEqual( + self.mapper.get_attr_spec("attr1"), + 
self.mapper.spec.get_attribute("attr1"), + ) + self.assertEqual( + self.mapper.get_attr_spec("data"), + self.mapper.spec.get_dataset("data"), + ) def _remap_nested_attr(self): - data_spec = self.mapper.spec.get_dataset('data') - self.mapper.map_spec('attr2', data_spec.get_attribute('attr2')) + data_spec = self.mapper.spec.get_dataset("data") + self.mapper.map_spec("attr2", data_spec.get_attribute("attr2")) class TestObjectMapperNoNesting(ObjectMapperMixin, TestCase): - def setUpBarSpec(self): - self.bar_spec = GroupSpec('A test group specification with a data type', - data_type_def='Bar', - datasets=[DatasetSpec('an example dataset', 'int', name='data')], - attributes=[AttributeSpec('attr1', 'an example string attribute', 'text'), - AttributeSpec('attr2', 'an example integer attribute', 'int')]) + self.bar_spec = GroupSpec( + "A test group specification with a data type", + data_type_def="Bar", + datasets=[DatasetSpec("an example dataset", "int", name="data")], + attributes=[ + AttributeSpec("attr1", "an example string attribute", "text"), + AttributeSpec("attr2", "an example integer attribute", "int"), + ], + ) def test_build(self): - ''' Test default mapping functionality when no attributes are nested ''' - container = Bar('my_bar', list(range(10)), 'value1', 10) + """Test default mapping functionality when no attributes are nested""" + container = Bar("my_bar", list(range(10)), "value1", 10) builder = self.mapper.build(container, self.manager) - expected = GroupBuilder('my_bar', datasets={'data': DatasetBuilder('data', list(range(10)))}, - attributes={'attr1': 'value1', 'attr2': 10}) + expected = GroupBuilder( + "my_bar", + datasets={"data": DatasetBuilder("data", list(range(10)))}, + attributes={"attr1": "value1", "attr2": 10}, + ) self.assertBuilderEqual(builder, expected) def test_build_empty(self): - ''' Test default mapping functionality when no attributes are nested ''' - container = Bar('my_bar', [], 'value1', 10) + """Test default mapping functionality when no attributes are nested""" + container = Bar("my_bar", [], "value1", 10) builder = self.mapper.build(container, self.manager) - expected = GroupBuilder('my_bar', datasets={'data': DatasetBuilder('data', [])}, - attributes={'attr1': 'value1', 'attr2': 10}) + expected = GroupBuilder( + "my_bar", + datasets={"data": DatasetBuilder("data", [])}, + attributes={"attr1": "value1", "attr2": 10}, + ) self.assertBuilderEqual(builder, expected) def test_construct(self): - expected = Bar('my_bar', list(range(10)), 'value1', 10) - builder = GroupBuilder('my_bar', datasets={'data': DatasetBuilder('data', list(range(10)))}, - attributes={'attr1': 'value1', 'attr2': 10, 'data_type': 'Bar', - 'namespace': CORE_NAMESPACE, 'object_id': expected.object_id}) + expected = Bar("my_bar", list(range(10)), "value1", 10) + builder = GroupBuilder( + "my_bar", + datasets={"data": DatasetBuilder("data", list(range(10)))}, + attributes={ + "attr1": "value1", + "attr2": 10, + "data_type": "Bar", + "namespace": CORE_NAMESPACE, + "object_id": expected.object_id, + }, + ) container = self.mapper.construct(builder, self.manager) self.assertEqual(container, expected) def test_default_mapping_keys(self): attr_map = self.mapper.get_attr_names(self.bar_spec) keys = set(attr_map.keys()) - expected = {'attr1', 'data', 'attr2'} + expected = {"attr1", "data", "attr2"} self.assertSetEqual(keys, expected) class TestObjectMapperContainer(ObjectMapperMixin, TestCase): - def setUpBarSpec(self): - self.bar_spec = GroupSpec('A test group specification with a data type', - 
data_type_def='Bar', - groups=[GroupSpec('an example group', data_type_def='Foo')], - attributes=[AttributeSpec('attr1', 'an example string attribute', 'text'), - AttributeSpec('attr2', 'an example integer attribute', 'int')]) + self.bar_spec = GroupSpec( + "A test group specification with a data type", + data_type_def="Bar", + groups=[GroupSpec("an example group", data_type_def="Foo")], + attributes=[ + AttributeSpec("attr1", "an example string attribute", "text"), + AttributeSpec("attr2", "an example integer attribute", "int"), + ], + ) def test_default_mapping_keys(self): attr_map = self.mapper.get_attr_names(self.bar_spec) keys = set(attr_map.keys()) - expected = {'attr1', 'foo', 'attr2'} + expected = {"attr1", "foo", "attr2"} self.assertSetEqual(keys, expected) class TestLinkedContainer(TestCase): - def setUp(self): - self.foo_spec = GroupSpec('A test group specification with data type Foo', data_type_def='Foo') - self.bar_spec = GroupSpec('A test group specification with a data type Bar', - data_type_def='Bar', - groups=[self.foo_spec], - datasets=[DatasetSpec('an example dataset', 'int', name='data')], - attributes=[AttributeSpec('attr1', 'an example string attribute', 'text'), - AttributeSpec('attr2', 'an example integer attribute', 'int')]) + self.foo_spec = GroupSpec("A test group specification with data type Foo", data_type_def="Foo") + self.bar_spec = GroupSpec( + "A test group specification with a data type Bar", + data_type_def="Bar", + groups=[self.foo_spec], + datasets=[DatasetSpec("an example dataset", "int", name="data")], + attributes=[ + AttributeSpec("attr1", "an example string attribute", "text"), + AttributeSpec("attr2", "an example integer attribute", "int"), + ], + ) self.spec_catalog = SpecCatalog() - self.spec_catalog.register_spec(self.foo_spec, 'test.yaml') - self.spec_catalog.register_spec(self.bar_spec, 'test.yaml') - self.namespace = SpecNamespace('a test namespace', CORE_NAMESPACE, - [{'source': 'test.yaml'}], - version='0.1.0', - catalog=self.spec_catalog) + self.spec_catalog.register_spec(self.foo_spec, "test.yaml") + self.spec_catalog.register_spec(self.bar_spec, "test.yaml") + self.namespace = SpecNamespace( + "a test namespace", + CORE_NAMESPACE, + [{"source": "test.yaml"}], + version="0.1.0", + catalog=self.spec_catalog, + ) self.namespace_catalog = NamespaceCatalog() self.namespace_catalog.add_namespace(CORE_NAMESPACE, self.namespace) self.type_map = TypeMap(self.namespace_catalog) - self.type_map.register_container_type(CORE_NAMESPACE, 'Foo', Foo) - self.type_map.register_container_type(CORE_NAMESPACE, 'Bar', Bar) + self.type_map.register_container_type(CORE_NAMESPACE, "Foo", Foo) + self.type_map.register_container_type(CORE_NAMESPACE, "Bar", Bar) self.manager = BuildManager(self.type_map) self.foo_mapper = ObjectMapper(self.foo_spec) self.bar_mapper = ObjectMapper(self.bar_spec) def test_build_child_link(self): - ''' Test default mapping functionality when one container contains a child link to another container ''' - foo_inst = Foo('my_foo') - bar_inst1 = Bar('my_bar1', list(range(10)), 'value1', 10, foo=foo_inst) + """Test default mapping functionality when one container contains a child link to another container""" + foo_inst = Foo("my_foo") + bar_inst1 = Bar("my_bar1", list(range(10)), "value1", 10, foo=foo_inst) # bar_inst2.foo should link to bar_inst1.foo - bar_inst2 = Bar('my_bar2', list(range(10)), 'value1', 10, foo=foo_inst) + bar_inst2 = Bar("my_bar2", list(range(10)), "value1", 10, foo=foo_inst) foo_builder = 
self.foo_mapper.build(foo_inst, self.manager) bar1_builder = self.bar_mapper.build(bar_inst1, self.manager) bar2_builder = self.bar_mapper.build(bar_inst2, self.manager) - foo_expected = GroupBuilder('my_foo') + foo_expected = GroupBuilder("my_foo") - inner_foo_builder = GroupBuilder('my_foo', - attributes={'data_type': 'Foo', - 'namespace': CORE_NAMESPACE, - 'object_id': foo_inst.object_id}) - bar1_expected = GroupBuilder('my_bar1', - datasets={'data': DatasetBuilder('data', list(range(10)))}, - groups={'foo': inner_foo_builder}, - attributes={'attr1': 'value1', 'attr2': 10}) + inner_foo_builder = GroupBuilder( + "my_foo", + attributes={ + "data_type": "Foo", + "namespace": CORE_NAMESPACE, + "object_id": foo_inst.object_id, + }, + ) + bar1_expected = GroupBuilder( + "my_bar1", + datasets={"data": DatasetBuilder("data", list(range(10)))}, + groups={"foo": inner_foo_builder}, + attributes={"attr1": "value1", "attr2": 10}, + ) link_foo_builder = LinkBuilder(builder=inner_foo_builder) - bar2_expected = GroupBuilder('my_bar2', - datasets={'data': DatasetBuilder('data', list(range(10)))}, - links={'foo': link_foo_builder}, - attributes={'attr1': 'value1', 'attr2': 10}) + bar2_expected = GroupBuilder( + "my_bar2", + datasets={"data": DatasetBuilder("data", list(range(10)))}, + links={"foo": link_foo_builder}, + attributes={"attr1": "value1", "attr2": 10}, + ) self.assertBuilderEqual(foo_builder, foo_expected) self.assertBuilderEqual(bar1_builder, bar1_expected) self.assertBuilderEqual(bar2_builder, bar2_expected) @unittest.expectedFailure def test_build_broken_link_parent(self): - ''' Test that building a container with a broken link that has a parent raises an error. ''' - foo_inst = Foo('my_foo') - Bar('my_bar1', list(range(10)), 'value1', 10, foo=foo_inst) # foo_inst.parent is this bar + """Test that building a container with a broken link that has a parent raises an error.""" + foo_inst = Foo("my_foo") + Bar("my_bar1", list(range(10)), "value1", 10, foo=foo_inst) # foo_inst.parent is this bar # bar_inst2.foo should link to bar_inst1.foo - bar_inst2 = Bar('my_bar2', list(range(10)), 'value1', 10, foo=foo_inst) + bar_inst2 = Bar("my_bar2", list(range(10)), "value1", 10, foo=foo_inst) # TODO bar_inst.foo.parent exists but is never built - this is a tricky edge case that should raise an error with self.assertRaises(OrphanContainerBuildError): self.bar_mapper.build(bar_inst2, self.manager) def test_build_broken_link_no_parent(self): - ''' Test that building a container with a broken link that has no parent raises an error. ''' - foo_inst = Foo('my_foo') - bar_inst1 = Bar('my_bar1', list(range(10)), 'value1', 10, foo=foo_inst) # foo_inst.parent is this bar + """Test that building a container with a broken link that has no parent raises an error.""" + foo_inst = Foo("my_foo") + bar_inst1 = Bar("my_bar1", list(range(10)), "value1", 10, foo=foo_inst) # foo_inst.parent is this bar # bar_inst2.foo should link to bar_inst1.foo - bar_inst2 = Bar('my_bar2', list(range(10)), 'value1', 10, foo=foo_inst) + bar_inst2 = Bar("my_bar2", list(range(10)), "value1", 10, foo=foo_inst) bar_inst1.remove_foo() - msg = ("my_bar2 (my_bar2): Linked Foo 'my_foo' has no parent. Remove the link or ensure the linked container " - "is added properly.") + msg = ( + "my_bar2 (my_bar2): Linked Foo 'my_foo' has no parent. Remove the link or" + " ensure the linked container is added properly." 
+ ) with self.assertRaisesWith(OrphanContainerBuildError, msg): self.bar_mapper.build(bar_inst2, self.manager) class TestReference(TestCase): - def setUp(self): - self.foo_spec = GroupSpec('A test group specification with data type Foo', data_type_def='Foo') - self.bar_spec = GroupSpec('A test group specification with a data type Bar', - data_type_def='Bar', - datasets=[DatasetSpec('an example dataset', 'int', name='data')], - attributes=[AttributeSpec('attr1', 'an example string attribute', 'text'), - AttributeSpec('attr2', 'an example integer attribute', 'int'), - AttributeSpec('foo', 'a referenced foo', RefSpec('Foo', 'object'), - required=False)]) + self.foo_spec = GroupSpec("A test group specification with data type Foo", data_type_def="Foo") + self.bar_spec = GroupSpec( + "A test group specification with a data type Bar", + data_type_def="Bar", + datasets=[DatasetSpec("an example dataset", "int", name="data")], + attributes=[ + AttributeSpec("attr1", "an example string attribute", "text"), + AttributeSpec("attr2", "an example integer attribute", "int"), + AttributeSpec( + "foo", + "a referenced foo", + RefSpec("Foo", "object"), + required=False, + ), + ], + ) self.spec_catalog = SpecCatalog() - self.spec_catalog.register_spec(self.foo_spec, 'test.yaml') - self.spec_catalog.register_spec(self.bar_spec, 'test.yaml') - self.namespace = SpecNamespace('a test namespace', CORE_NAMESPACE, - [{'source': 'test.yaml'}], - version='0.1.0', - catalog=self.spec_catalog) + self.spec_catalog.register_spec(self.foo_spec, "test.yaml") + self.spec_catalog.register_spec(self.bar_spec, "test.yaml") + self.namespace = SpecNamespace( + "a test namespace", + CORE_NAMESPACE, + [{"source": "test.yaml"}], + version="0.1.0", + catalog=self.spec_catalog, + ) self.namespace_catalog = NamespaceCatalog() self.namespace_catalog.add_namespace(CORE_NAMESPACE, self.namespace) self.type_map = TypeMap(self.namespace_catalog) - self.type_map.register_container_type(CORE_NAMESPACE, 'Foo', Foo) - self.type_map.register_container_type(CORE_NAMESPACE, 'Bar', Bar) + self.type_map.register_container_type(CORE_NAMESPACE, "Foo", Foo) + self.type_map.register_container_type(CORE_NAMESPACE, "Bar", Bar) self.manager = BuildManager(self.type_map) self.foo_mapper = ObjectMapper(self.foo_spec) self.bar_mapper = ObjectMapper(self.bar_spec) def test_build_attr_ref(self): - ''' Test default mapping functionality when one container contains an attribute reference to another container. 
- ''' - foo_inst = Foo('my_foo') - bar_inst1 = Bar('my_bar1', list(range(10)), 'value1', 10, foo=foo_inst) - bar_inst2 = Bar('my_bar2', list(range(10)), 'value1', 10) + """Test default mapping functionality when one container contains an attr reference to another container.""" + foo_inst = Foo("my_foo") + bar_inst1 = Bar("my_bar1", list(range(10)), "value1", 10, foo=foo_inst) + bar_inst2 = Bar("my_bar2", list(range(10)), "value1", 10) foo_builder = self.manager.build(foo_inst, root=True) bar1_builder = self.manager.build(bar_inst1, root=True) # adds refs bar2_builder = self.manager.build(bar_inst2, root=True) - foo_expected = GroupBuilder('my_foo', - attributes={'data_type': 'Foo', - 'namespace': CORE_NAMESPACE, - 'object_id': foo_inst.object_id}) - bar1_expected = GroupBuilder('n/a', # name doesn't matter - datasets={'data': DatasetBuilder('data', list(range(10)))}, - attributes={'attr1': 'value1', - 'attr2': 10, - 'foo': ReferenceBuilder(foo_expected), - 'data_type': 'Bar', - 'namespace': CORE_NAMESPACE, - 'object_id': bar_inst1.object_id}) - bar2_expected = GroupBuilder('n/a', # name doesn't matter - datasets={'data': DatasetBuilder('data', list(range(10)))}, - attributes={'attr1': 'value1', - 'attr2': 10, - 'data_type': 'Bar', - 'namespace': CORE_NAMESPACE, - 'object_id': bar_inst2.object_id}) + foo_expected = GroupBuilder( + "my_foo", + attributes={ + "data_type": "Foo", + "namespace": CORE_NAMESPACE, + "object_id": foo_inst.object_id, + }, + ) + bar1_expected = GroupBuilder( + "n/a", # name doesn't matter + datasets={"data": DatasetBuilder("data", list(range(10)))}, + attributes={ + "attr1": "value1", + "attr2": 10, + "foo": ReferenceBuilder(foo_expected), + "data_type": "Bar", + "namespace": CORE_NAMESPACE, + "object_id": bar_inst1.object_id, + }, + ) + bar2_expected = GroupBuilder( + "n/a", # name doesn't matter + datasets={"data": DatasetBuilder("data", list(range(10)))}, + attributes={ + "attr1": "value1", + "attr2": 10, + "data_type": "Bar", + "namespace": CORE_NAMESPACE, + "object_id": bar_inst2.object_id, + }, + ) self.assertDictEqual(foo_builder, foo_expected) self.assertDictEqual(bar1_builder, bar1_expected) self.assertDictEqual(bar2_builder, bar2_expected) def test_build_attr_ref_invalid(self): - ''' Test default mapping functionality when one container contains an attribute reference to another container. 
- ''' - bar_inst1 = Bar('my_bar1', list(range(10)), 'value1', 10) + """Test default mapping functionality when one container contains an attr reference to another container.""" + bar_inst1 = Bar("my_bar1", list(range(10)), "value1", 10) bar_inst1._Bar__foo = object() # make foo object a non-container type msg = "invalid type for reference 'foo' () - must be AbstractContainer" @@ -675,69 +780,65 @@ def test_build_attr_ref_invalid(self): class TestMissingRequiredAttribute(ObjectMapperMixin, TestCase): - def setUpBarSpec(self): self.bar_spec = GroupSpec( - doc='A test group specification with a data type Bar', - data_type_def='Bar', - attributes=[AttributeSpec('attr1', 'an example string attribute', 'text'), - AttributeSpec('attr2', 'an example integer attribute', 'int')] + doc="A test group specification with a data type Bar", + data_type_def="Bar", + attributes=[ + AttributeSpec("attr1", "an example string attribute", "text"), + AttributeSpec("attr2", "an example integer attribute", "int"), + ], ) def test_required_attr_missing(self): """Test mapping when one container is missing a required attribute.""" - bar_inst1 = Bar('my_bar1', list(range(10)), 'value1', 10) + bar_inst1 = Bar("my_bar1", list(range(10)), "value1", 10) bar_inst1._Bar__attr1 = None # make attr1 attribute None msg = "Bar 'my_bar1' is missing required value for attribute 'attr1'." with self.assertWarnsWith(MissingRequiredBuildWarning, msg): builder = self.mapper.build(bar_inst1, self.manager) - expected = GroupBuilder( - name='my_bar1', - attributes={'attr2': 10} - ) + expected = GroupBuilder(name="my_bar1", attributes={"attr2": 10}) self.assertBuilderEqual(expected, builder) class TestMissingRequiredAttributeRef(ObjectMapperMixin, TestCase): - def setUpBarSpec(self): self.bar_spec = GroupSpec( - doc='A test group specification with a data type Bar', - data_type_def='Bar', - attributes=[AttributeSpec('foo', 'a referenced foo', RefSpec('Foo', 'object'))] + doc="A test group specification with a data type Bar", + data_type_def="Bar", + attributes=[AttributeSpec("foo", "a referenced foo", RefSpec("Foo", "object"))], ) def test_required_attr_ref_missing(self): """Test mapping when one container is missing a required attribute reference.""" - bar_inst1 = Bar('my_bar1', list(range(10)), 'value1', 10) + bar_inst1 = Bar("my_bar1", list(range(10)), "value1", 10) msg = "Bar 'my_bar1' is missing required value for attribute 'foo'." with self.assertWarnsWith(MissingRequiredBuildWarning, msg): builder = self.mapper.build(bar_inst1, self.manager) expected = GroupBuilder( - name='my_bar1', + name="my_bar1", ) self.assertBuilderEqual(expected, builder) class TestMissingRequiredDataset(ObjectMapperMixin, TestCase): - def setUpBarSpec(self): self.bar_spec = GroupSpec( - doc='A test group specification with a data type Bar', - data_type_def='Bar', - datasets=[DatasetSpec('an example dataset', 'int', name='data')] + doc="A test group specification with a data type Bar", + data_type_def="Bar", + datasets=[DatasetSpec("an example dataset", "int", name="data")], ) def test_required_dataset_missing(self): """Test mapping when one container is missing a required dataset.""" - bar_inst1 = Bar('my_bar1', list(range(10)), 'value1', 10) + bar_inst1 = Bar("my_bar1", list(range(10)), "value1", 10) bar_inst1._Bar__data = None # make data dataset None msg = "Bar 'my_bar1' is missing required value for attribute 'data'." 
@@ -745,136 +846,144 @@ def test_required_dataset_missing(self): builder = self.mapper.build(bar_inst1, self.manager) expected = GroupBuilder( - name='my_bar1', + name="my_bar1", ) self.assertBuilderEqual(expected, builder) class TestMissingRequiredGroup(ObjectMapperMixin, TestCase): - def setUpBarSpec(self): self.bar_spec = GroupSpec( - doc='A test group specification with a data type Bar', - data_type_def='Bar', - groups=[GroupSpec('foo', data_type_inc='Foo')] + doc="A test group specification with a data type Bar", + data_type_def="Bar", + groups=[GroupSpec("foo", data_type_inc="Foo")], ) def test_required_group_missing(self): """Test mapping when one container is missing a required group.""" - bar_inst1 = Bar('my_bar1', list(range(10)), 'value1', 10) + bar_inst1 = Bar("my_bar1", list(range(10)), "value1", 10) msg = "Bar 'my_bar1' is missing required value for attribute 'foo'." with self.assertWarnsWith(MissingRequiredBuildWarning, msg): builder = self.mapper.build(bar_inst1, self.manager) expected = GroupBuilder( - name='my_bar1', + name="my_bar1", ) self.assertBuilderEqual(expected, builder) class TestRequiredEmptyGroup(ObjectMapperMixin, TestCase): - def setUpBarSpec(self): self.bar_spec = GroupSpec( - doc='A test group specification with a data type Bar', - data_type_def='Bar', - groups=[GroupSpec(name='empty', doc='empty group')], + doc="A test group specification with a data type Bar", + data_type_def="Bar", + groups=[GroupSpec(name="empty", doc="empty group")], ) def test_required_group_empty(self): """Test mapping when one container has a required empty group.""" - bar_inst1 = Bar('my_bar1', list(range(10)), 'value1', 10) + bar_inst1 = Bar("my_bar1", list(range(10)), "value1", 10) builder = self.mapper.build(bar_inst1, self.manager) expected = GroupBuilder( - name='my_bar1', - groups={'empty': GroupBuilder('empty')}, + name="my_bar1", + groups={"empty": GroupBuilder("empty")}, ) self.assertBuilderEqual(expected, builder) class TestOptionalEmptyGroup(ObjectMapperMixin, TestCase): - def setUpBarSpec(self): self.bar_spec = GroupSpec( - doc='A test group specification with a data type Bar', - data_type_def='Bar', - groups=[GroupSpec( - name='empty', - doc='empty group', - quantity='?', - attributes=[AttributeSpec('attr3', 'an optional float attribute', 'float', required=False)] - )] + doc="A test group specification with a data type Bar", + data_type_def="Bar", + groups=[ + GroupSpec( + name="empty", + doc="empty group", + quantity="?", + attributes=[ + AttributeSpec( + "attr3", + "an optional float attribute", + "float", + required=False, + ) + ], + ) + ], ) def test_optional_group_empty(self): """Test mapping when one container has an optional empty group.""" - self.mapper.map_spec('attr3', self.mapper.spec.get_group('empty').get_attribute('attr3')) + self.mapper.map_spec("attr3", self.mapper.spec.get_group("empty").get_attribute("attr3")) - bar_inst1 = Bar('my_bar1', list(range(10)), 'value1', 10) + bar_inst1 = Bar("my_bar1", list(range(10)), "value1", 10) bar_inst1._Bar__attr3 = None # force attr3 to be None builder = self.mapper.build(bar_inst1, self.manager) expected = GroupBuilder( - name='my_bar1', + name="my_bar1", ) self.assertBuilderEqual(expected, builder) def test_optional_group_not_empty(self): """Test mapping when one container has an optional not empty group.""" - self.mapper.map_spec('attr3', self.mapper.spec.get_group('empty').get_attribute('attr3')) + self.mapper.map_spec("attr3", self.mapper.spec.get_group("empty").get_attribute("attr3")) - bar_inst1 = 
Bar('my_bar1', list(range(10)), 'value1', 10, attr3=1.23) + bar_inst1 = Bar("my_bar1", list(range(10)), "value1", 10, attr3=1.23) builder = self.mapper.build(bar_inst1, self.manager) expected = GroupBuilder( - name='my_bar1', - groups={'empty': GroupBuilder( - name='empty', - attributes={'attr3': 1.23}, - )}, + name="my_bar1", + groups={ + "empty": GroupBuilder( + name="empty", + attributes={"attr3": 1.23}, + ) + }, ) self.assertBuilderEqual(expected, builder) class TestFixedAttributeValue(ObjectMapperMixin, TestCase): - def setUpBarSpec(self): self.bar_spec = GroupSpec( - doc='A test group specification with a data type Bar', - data_type_def='Bar', - attributes=[AttributeSpec('attr1', 'an example string attribute', 'text', value='hi'), - AttributeSpec('attr2', 'an example integer attribute', 'int')] + doc="A test group specification with a data type Bar", + data_type_def="Bar", + attributes=[ + AttributeSpec("attr1", "an example string attribute", "text", value="hi"), + AttributeSpec("attr2", "an example integer attribute", "int"), + ], ) def test_required_attr_missing(self): """Test mapping when one container has a required attribute with a fixed value.""" - bar_inst1 = Bar('my_bar1', list(range(10)), 'value1', 10) # attr1=value1 is not processed + bar_inst1 = Bar("my_bar1", list(range(10)), "value1", 10) # attr1=value1 is not processed builder = self.mapper.build(bar_inst1, self.manager) - expected = GroupBuilder( - name='my_bar1', - attributes={'attr1': 'hi', 'attr2': 10} - ) + expected = GroupBuilder(name="my_bar1", attributes={"attr1": "hi", "attr2": 10}) self.assertBuilderEqual(builder, expected) class TestObjectMapperBadValue(TestCase): - def test_bad_value(self): """Test that an error is raised if the container attribute value for a spec with a data type is not a container or collection of containers. 
""" + class Qux(Container): - @docval({'name': 'name', 'type': str, 'doc': 'the name of this Qux'}, - {'name': 'foo', 'type': int, 'doc': 'a group'}) + @docval( + {"name": "name", "type": str, "doc": "the name of this Qux"}, + {"name": "foo", "type": int, "doc": "a group"}, + ) def __init__(self, **kwargs): - name, foo = getargs('name', 'foo', kwargs) + name, foo = getargs("name", "foo", kwargs) super().__init__(name=name) self.__foo = foo if isinstance(foo, Foo): @@ -885,26 +994,30 @@ def foo(self): return self.__foo self.qux_spec = GroupSpec( - doc='A test group specification with data type Qux', - data_type_def='Qux', - groups=[GroupSpec('an example dataset', data_type_inc='Foo')] + doc="A test group specification with data type Qux", + data_type_def="Qux", + groups=[GroupSpec("an example dataset", data_type_inc="Foo")], ) - self.foo_spec = GroupSpec('A test group specification with data type Foo', data_type_def='Foo') + self.foo_spec = GroupSpec("A test group specification with data type Foo", data_type_def="Foo") self.spec_catalog = SpecCatalog() - self.spec_catalog.register_spec(self.qux_spec, 'test.yaml') - self.spec_catalog.register_spec(self.foo_spec, 'test.yaml') - self.namespace = SpecNamespace('a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], - version='0.1.0', - catalog=self.spec_catalog) + self.spec_catalog.register_spec(self.qux_spec, "test.yaml") + self.spec_catalog.register_spec(self.foo_spec, "test.yaml") + self.namespace = SpecNamespace( + "a test namespace", + CORE_NAMESPACE, + [{"source": "test.yaml"}], + version="0.1.0", + catalog=self.spec_catalog, + ) self.namespace_catalog = NamespaceCatalog() self.namespace_catalog.add_namespace(CORE_NAMESPACE, self.namespace) self.type_map = TypeMap(self.namespace_catalog) - self.type_map.register_container_type(CORE_NAMESPACE, 'Qux', Qux) - self.type_map.register_container_type(CORE_NAMESPACE, 'Foo', Foo) + self.type_map.register_container_type(CORE_NAMESPACE, "Qux", Qux) + self.type_map.register_container_type(CORE_NAMESPACE, "Foo", Foo) self.manager = BuildManager(self.type_map) self.mapper = ObjectMapper(self.qux_spec) - container = Qux('my_qux', foo=1) + container = Qux("my_qux", foo=1) msg = "Qux 'my_qux' attribute 'foo' has unexpected type." 
with self.assertRaisesWith(ContainerConfigurationError, msg): self.mapper.build(container, self.manager) diff --git a/tests/unit/build_tests/test_io_map_data.py b/tests/unit/build_tests/test_io_map_data.py index d9b474c56..6c9fc6dec 100644 --- a/tests/unit/build_tests/test_io_map_data.py +++ b/tests/unit/build_tests/test_io_map_data.py @@ -2,27 +2,44 @@ import h5py import numpy as np + from hdmf import Container, Data from hdmf.backends.hdf5 import H5DataIO -from hdmf.build import (GroupBuilder, DatasetBuilder, ObjectMapper, BuildManager, TypeMap, ReferenceBuilder, - ReferenceTargetNotBuiltError) +from hdmf.build import ( + BuildManager, + DatasetBuilder, + GroupBuilder, + ObjectMapper, + ReferenceBuilder, + ReferenceTargetNotBuiltError, + TypeMap, +) from hdmf.data_utils import DataChunkIterator -from hdmf.spec import (AttributeSpec, DatasetSpec, DtypeSpec, GroupSpec, SpecCatalog, SpecNamespace, NamespaceCatalog, - RefSpec) +from hdmf.spec import ( + AttributeSpec, + DatasetSpec, + DtypeSpec, + GroupSpec, + NamespaceCatalog, + RefSpec, + SpecCatalog, + SpecNamespace, +) from hdmf.spec.spec import ZERO_OR_MANY from hdmf.testing import TestCase from hdmf.utils import docval, getargs -from tests.unit.helpers.utils import Foo, CORE_NAMESPACE +from ..helpers.utils import CORE_NAMESPACE, Foo class Baz(Data): - - @docval({'name': 'name', 'type': str, 'doc': 'the name of this Baz'}, - {'name': 'data', 'type': (list, h5py.Dataset, 'data', 'array_data'), 'doc': 'some data'}, - {'name': 'baz_attr', 'type': str, 'doc': 'an attribute'}) + @docval( + {"name": "name", "type": str, "doc": "the name of this Baz"}, + {"name": "data", "type": (list, h5py.Dataset, "data", "array_data"), "doc": "some data"}, + {"name": "baz_attr", "type": str, "doc": "an attribute"}, + ) def __init__(self, **kwargs): - name, data, baz_attr = getargs('name', 'data', 'baz_attr', kwargs) + name, data, baz_attr = getargs("name", "data", "baz_attr", kwargs) super().__init__(name=name, data=data) self.__baz_attr = baz_attr @@ -32,11 +49,12 @@ def baz_attr(self): class BazHolder(Container): - - @docval({'name': 'name', 'type': str, 'doc': 'the name of this Baz'}, - {'name': 'bazs', 'type': list, 'doc': 'some Baz data', 'default': list()}) + @docval( + {"name": "name", "type": str, "doc": "the name of this Baz"}, + {"name": "bazs", "type": list, "doc": "some Baz data", "default": list()}, + ) def __init__(self, **kwargs): - name, bazs = getargs('name', 'bazs', kwargs) + name, bazs = getargs("name", "bazs", kwargs) super().__init__(name=name) self.__bazs = {b.name: b for b in bazs} # note: collections of groups are unordered in HDF5 for b in bazs: @@ -48,411 +66,460 @@ def bazs(self): class BazSpecMixin: - def setUp(self): self.setUpBazSpec() self.spec_catalog = SpecCatalog() - self.spec_catalog.register_spec(self.baz_spec, 'test.yaml') - self.namespace = SpecNamespace('a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], - version='0.1.0', - catalog=self.spec_catalog) + self.spec_catalog.register_spec(self.baz_spec, "test.yaml") + self.namespace = SpecNamespace( + "a test namespace", + CORE_NAMESPACE, + [{"source": "test.yaml"}], + version="0.1.0", + catalog=self.spec_catalog, + ) self.namespace_catalog = NamespaceCatalog() self.namespace_catalog.add_namespace(CORE_NAMESPACE, self.namespace) self.type_map = TypeMap(self.namespace_catalog) - self.type_map.register_container_type(CORE_NAMESPACE, 'Baz', Baz) + self.type_map.register_container_type(CORE_NAMESPACE, "Baz", Baz) self.type_map.register_map(Baz, ObjectMapper) 
self.manager = BuildManager(self.type_map) self.mapper = ObjectMapper(self.baz_spec) def setUpBazSpec(self): - raise NotImplementedError('Test must implement this method.') + raise NotImplementedError("Test must implement this method.") class TestDataMap(BazSpecMixin, TestCase): - def setUp(self): self.setUpBazSpec() self.spec_catalog = SpecCatalog() - self.spec_catalog.register_spec(self.baz_spec, 'test.yaml') - self.namespace = SpecNamespace('a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], - version='0.1.0', - catalog=self.spec_catalog) + self.spec_catalog.register_spec(self.baz_spec, "test.yaml") + self.namespace = SpecNamespace( + "a test namespace", + CORE_NAMESPACE, + [{"source": "test.yaml"}], + version="0.1.0", + catalog=self.spec_catalog, + ) self.namespace_catalog = NamespaceCatalog() self.namespace_catalog.add_namespace(CORE_NAMESPACE, self.namespace) self.type_map = TypeMap(self.namespace_catalog) - self.type_map.register_container_type(CORE_NAMESPACE, 'Baz', Baz) + self.type_map.register_container_type(CORE_NAMESPACE, "Baz", Baz) self.type_map.register_map(Baz, ObjectMapper) self.manager = BuildManager(self.type_map) self.mapper = ObjectMapper(self.baz_spec) def setUpBazSpec(self): self.baz_spec = DatasetSpec( - doc='an Baz type', - dtype='int', - name='MyBaz', - data_type_def='Baz', + doc="an Baz type", + dtype="int", + name="MyBaz", + data_type_def="Baz", shape=[None], - attributes=[AttributeSpec('baz_attr', 'an example string attribute', 'text')] + attributes=[AttributeSpec("baz_attr", "an example string attribute", "text")], ) def test_build(self): - ''' Test default mapping functionality when no attributes are nested ''' - container = Baz('MyBaz', list(range(10)), 'abcdefghijklmnopqrstuvwxyz') + """Test default mapping functionality when no attributes are nested""" + container = Baz("MyBaz", list(range(10)), "abcdefghijklmnopqrstuvwxyz") builder = self.mapper.build(container, self.manager) - expected = DatasetBuilder('MyBaz', list(range(10)), attributes={'baz_attr': 'abcdefghijklmnopqrstuvwxyz'}) + expected = DatasetBuilder( + "MyBaz", + list(range(10)), + attributes={"baz_attr": "abcdefghijklmnopqrstuvwxyz"}, + ) self.assertBuilderEqual(builder, expected) def test_build_empty_data(self): """Test building of a Data object with empty data.""" - baz_inc_spec = DatasetSpec(doc='doc', data_type_inc='Baz', quantity=ZERO_OR_MANY) - baz_holder_spec = GroupSpec(doc='doc', data_type_def='BazHolder', datasets=[baz_inc_spec]) - self.spec_catalog.register_spec(baz_holder_spec, 'test.yaml') - self.type_map.register_container_type(CORE_NAMESPACE, 'BazHolder', BazHolder) + baz_inc_spec = DatasetSpec(doc="doc", data_type_inc="Baz", quantity=ZERO_OR_MANY) + baz_holder_spec = GroupSpec(doc="doc", data_type_def="BazHolder", datasets=[baz_inc_spec]) + self.spec_catalog.register_spec(baz_holder_spec, "test.yaml") + self.type_map.register_container_type(CORE_NAMESPACE, "BazHolder", BazHolder) self.holder_mapper = ObjectMapper(baz_holder_spec) - baz = Baz('MyBaz', [], 'abcdefghijklmnopqrstuvwxyz') - holder = BazHolder('holder', [baz]) + baz = Baz("MyBaz", [], "abcdefghijklmnopqrstuvwxyz") + holder = BazHolder("holder", [baz]) builder = self.holder_mapper.build(holder, self.manager) expected = GroupBuilder( - name='holder', - datasets=[DatasetBuilder( - name='MyBaz', - data=[], - attributes={'baz_attr': 'abcdefghijklmnopqrstuvwxyz', - 'data_type': 'Baz', - 'namespace': 'test_core', - 'object_id': baz.object_id} - )] + name="holder", + datasets=[ + DatasetBuilder( + name="MyBaz", + 
data=[], + attributes={ + "baz_attr": "abcdefghijklmnopqrstuvwxyz", + "data_type": "Baz", + "namespace": "test_core", + "object_id": baz.object_id, + }, + ) + ], ) self.assertBuilderEqual(builder, expected) def test_append(self): - with h5py.File('test.h5', 'w') as file: - test_ds = file.create_dataset('test_ds', data=[1, 2, 3], chunks=True, maxshape=(None,)) - container = Baz('MyBaz', test_ds, 'abcdefghijklmnopqrstuvwxyz') + with h5py.File("test.h5", "w") as file: + test_ds = file.create_dataset("test_ds", data=[1, 2, 3], chunks=True, maxshape=(None,)) + container = Baz("MyBaz", test_ds, "abcdefghijklmnopqrstuvwxyz") container.append(4) np.testing.assert_array_equal(container[:], [1, 2, 3, 4]) - os.remove('test.h5') + os.remove("test.h5") def test_extend(self): - with h5py.File('test.h5', 'w') as file: - test_ds = file.create_dataset('test_ds', data=[1, 2, 3], chunks=True, maxshape=(None,)) - container = Baz('MyBaz', test_ds, 'abcdefghijklmnopqrstuvwxyz') + with h5py.File("test.h5", "w") as file: + test_ds = file.create_dataset("test_ds", data=[1, 2, 3], chunks=True, maxshape=(None,)) + container = Baz("MyBaz", test_ds, "abcdefghijklmnopqrstuvwxyz") container.extend([4, 5]) np.testing.assert_array_equal(container[:], [1, 2, 3, 4, 5]) - os.remove('test.h5') + os.remove("test.h5") class BazScalar(Data): - - @docval({'name': 'name', 'type': str, 'doc': 'the name of this BazScalar'}, - {'name': 'data', 'type': int, 'doc': 'some data'}) + @docval( + {"name": "name", "type": str, "doc": "the name of this BazScalar"}, + {"name": "data", "type": int, "doc": "some data"}, + ) def __init__(self, **kwargs): super().__init__(**kwargs) class TestDataMapScalar(TestCase): - def setUp(self): self.setUpBazSpec() self.spec_catalog = SpecCatalog() - self.spec_catalog.register_spec(self.baz_spec, 'test.yaml') - self.namespace = SpecNamespace('a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], - version='0.1.0', - catalog=self.spec_catalog) + self.spec_catalog.register_spec(self.baz_spec, "test.yaml") + self.namespace = SpecNamespace( + "a test namespace", + CORE_NAMESPACE, + [{"source": "test.yaml"}], + version="0.1.0", + catalog=self.spec_catalog, + ) self.namespace_catalog = NamespaceCatalog() self.namespace_catalog.add_namespace(CORE_NAMESPACE, self.namespace) self.type_map = TypeMap(self.namespace_catalog) - self.type_map.register_container_type(CORE_NAMESPACE, 'BazScalar', BazScalar) + self.type_map.register_container_type(CORE_NAMESPACE, "BazScalar", BazScalar) self.type_map.register_map(BazScalar, ObjectMapper) self.manager = BuildManager(self.type_map) self.mapper = ObjectMapper(self.baz_spec) def setUpBazSpec(self): self.baz_spec = DatasetSpec( - doc='a BazScalar type', - dtype='int', - name='MyBaz', - data_type_def='BazScalar' + doc="a BazScalar type", + dtype="int", + name="MyBaz", + data_type_def="BazScalar", ) def test_construct_scalar_dataset(self): """Test constructing a Data object with an h5py.Dataset with shape (1, ) for scalar spec.""" - with h5py.File('test.h5', 'w') as file: - test_ds = file.create_dataset('test_ds', data=[1]) + with h5py.File("test.h5", "w") as file: + test_ds = file.create_dataset("test_ds", data=[1]) expected = BazScalar( - name='MyBaz', + name="MyBaz", data=1, ) builder = DatasetBuilder( - name='MyBaz', + name="MyBaz", data=test_ds, - attributes={'data_type': 'BazScalar', - 'namespace': CORE_NAMESPACE, - 'object_id': expected.object_id}, + attributes={ + "data_type": "BazScalar", + "namespace": CORE_NAMESPACE, + "object_id": expected.object_id, + }, ) 
container = self.mapper.construct(builder, self.manager) self.assertTrue(np.issubdtype(type(container.data), np.integer)) # as opposed to h5py.Dataset self.assertContainerEqual(container, expected) - os.remove('test.h5') + os.remove("test.h5") class BazScalarCompound(Data): - - @docval({'name': 'name', 'type': str, 'doc': 'the name of this BazScalar'}, - {'name': 'data', 'type': 'array_data', 'doc': 'some data'}) + @docval( + {"name": "name", "type": str, "doc": "the name of this BazScalar"}, + {"name": "data", "type": "array_data", "doc": "some data"}, + ) def __init__(self, **kwargs): super().__init__(**kwargs) class TestDataMapScalarCompound(TestCase): - def setUp(self): self.setUpBazSpec() self.spec_catalog = SpecCatalog() - self.spec_catalog.register_spec(self.baz_spec, 'test.yaml') - self.namespace = SpecNamespace('a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], - version='0.1.0', - catalog=self.spec_catalog) + self.spec_catalog.register_spec(self.baz_spec, "test.yaml") + self.namespace = SpecNamespace( + "a test namespace", + CORE_NAMESPACE, + [{"source": "test.yaml"}], + version="0.1.0", + catalog=self.spec_catalog, + ) self.namespace_catalog = NamespaceCatalog() self.namespace_catalog.add_namespace(CORE_NAMESPACE, self.namespace) self.type_map = TypeMap(self.namespace_catalog) - self.type_map.register_container_type(CORE_NAMESPACE, 'BazScalarCompound', BazScalarCompound) + self.type_map.register_container_type(CORE_NAMESPACE, "BazScalarCompound", BazScalarCompound) self.type_map.register_map(BazScalarCompound, ObjectMapper) self.manager = BuildManager(self.type_map) self.mapper = ObjectMapper(self.baz_spec) def setUpBazSpec(self): self.baz_spec = DatasetSpec( - doc='a BazScalarCompound type', + doc="a BazScalarCompound type", dtype=[ DtypeSpec( - name='id', - dtype='uint64', - doc='The unique identifier in this table.' - ), - DtypeSpec( - name='attr1', - dtype='text', - doc='A text attribute.' 
+ name="id", + dtype="uint64", + doc="The unique identifier in this table.", ), + DtypeSpec(name="attr1", dtype="text", doc="A text attribute."), ], - name='MyBaz', - data_type_def='BazScalarCompound', + name="MyBaz", + data_type_def="BazScalarCompound", ) def test_construct_scalar_compound_dataset(self): """Test construct on a compound h5py.Dataset with shape (1, ) for scalar spec does not resolve the data.""" - with h5py.File('test.h5', 'w') as file: - comp_type = np.dtype([('id', np.uint64), ('attr1', h5py.special_dtype(vlen=str))]) + with h5py.File("test.h5", "w") as file: + comp_type = np.dtype([("id", np.uint64), ("attr1", h5py.special_dtype(vlen=str))]) test_ds = file.create_dataset( - name='test_ds', - data=np.array((1, 'text'), dtype=comp_type), - shape=(1, ), - dtype=comp_type + name="test_ds", + data=np.array((1, "text"), dtype=comp_type), + shape=(1,), + dtype=comp_type, ) expected = BazScalarCompound( - name='MyBaz', - data=(1, 'text'), + name="MyBaz", + data=(1, "text"), ) builder = DatasetBuilder( - name='MyBaz', + name="MyBaz", data=test_ds, - attributes={'data_type': 'BazScalarCompound', - 'namespace': CORE_NAMESPACE, - 'object_id': expected.object_id}, + attributes={ + "data_type": "BazScalarCompound", + "namespace": CORE_NAMESPACE, + "object_id": expected.object_id, + }, ) container = self.mapper.construct(builder, self.manager) self.assertEqual(type(container.data), h5py.Dataset) self.assertContainerEqual(container, expected) - os.remove('test.h5') + os.remove("test.h5") class BuildDatasetOfReferencesMixin: - def setUp(self): self.setUpBazSpec() self.foo_spec = GroupSpec( - doc='A test group specification with a data type', - data_type_def='Foo', - datasets=[ - DatasetSpec(name='my_data', doc='an example dataset', dtype='int') - ], + doc="A test group specification with a data type", + data_type_def="Foo", + datasets=[DatasetSpec(name="my_data", doc="an example dataset", dtype="int")], attributes=[ - AttributeSpec(name='attr1', doc='an example string attribute', dtype='text'), - AttributeSpec(name='attr2', doc='an example int attribute', dtype='int'), - AttributeSpec(name='attr3', doc='an example float attribute', dtype='float') - ] + AttributeSpec( + name="attr1", + doc="an example string attribute", + dtype="text", + ), + AttributeSpec(name="attr2", doc="an example int attribute", dtype="int"), + AttributeSpec( + name="attr3", + doc="an example float attribute", + dtype="float", + ), + ], ) self.spec_catalog = SpecCatalog() - self.spec_catalog.register_spec(self.baz_spec, 'test.yaml') - self.spec_catalog.register_spec(self.foo_spec, 'test.yaml') - self.namespace = SpecNamespace('a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], - version='0.1.0', - catalog=self.spec_catalog) + self.spec_catalog.register_spec(self.baz_spec, "test.yaml") + self.spec_catalog.register_spec(self.foo_spec, "test.yaml") + self.namespace = SpecNamespace( + "a test namespace", + CORE_NAMESPACE, + [{"source": "test.yaml"}], + version="0.1.0", + catalog=self.spec_catalog, + ) self.namespace_catalog = NamespaceCatalog() self.namespace_catalog.add_namespace(CORE_NAMESPACE, self.namespace) self.type_map = TypeMap(self.namespace_catalog) - self.type_map.register_container_type(CORE_NAMESPACE, 'Baz', Baz) - self.type_map.register_container_type(CORE_NAMESPACE, 'Foo', Foo) + self.type_map.register_container_type(CORE_NAMESPACE, "Baz", Baz) + self.type_map.register_container_type(CORE_NAMESPACE, "Foo", Foo) self.type_map.register_map(Baz, ObjectMapper) self.type_map.register_map(Foo, 
ObjectMapper) self.manager = BuildManager(self.type_map) class TestBuildUntypedDatasetOfReferences(BuildDatasetOfReferencesMixin, TestCase): - def setUpBazSpec(self): self.baz_spec = DatasetSpec( - doc='a list of references to Foo objects', + doc="a list of references to Foo objects", dtype=None, - name='MyBaz', + name="MyBaz", shape=[None], - data_type_def='Baz', - attributes=[AttributeSpec('baz_attr', 'an example string attribute', 'text')] + data_type_def="Baz", + attributes=[AttributeSpec("baz_attr", "an example string attribute", "text")], ) def test_build(self): - ''' Test default mapping functionality when no attributes are nested ''' - foo = Foo('my_foo1', [1, 2, 3], 'string', 10) - baz = Baz('MyBaz', [foo, None], 'abcdefghijklmnopqrstuvwxyz') + """Test default mapping functionality when no attributes are nested""" + foo = Foo("my_foo1", [1, 2, 3], "string", 10) + baz = Baz("MyBaz", [foo, None], "abcdefghijklmnopqrstuvwxyz") foo_builder = self.manager.build(foo) baz_builder = self.manager.build(baz, root=True) - expected = DatasetBuilder('MyBaz', [ReferenceBuilder(foo_builder), None], - attributes={'baz_attr': 'abcdefghijklmnopqrstuvwxyz', - 'data_type': 'Baz', - 'namespace': CORE_NAMESPACE, - 'object_id': baz.object_id}) + expected = DatasetBuilder( + "MyBaz", + [ReferenceBuilder(foo_builder), None], + attributes={ + "baz_attr": "abcdefghijklmnopqrstuvwxyz", + "data_type": "Baz", + "namespace": CORE_NAMESPACE, + "object_id": baz.object_id, + }, + ) self.assertBuilderEqual(baz_builder, expected) class TestBuildCompoundDatasetOfReferences(BuildDatasetOfReferencesMixin, TestCase): - def setUpBazSpec(self): self.baz_spec = DatasetSpec( - doc='a list of references to Foo objects', + doc="a list of references to Foo objects", dtype=[ DtypeSpec( - name='id', - dtype='uint64', - doc='The unique identifier in this table.' + name="id", + dtype="uint64", + doc="The unique identifier in this table.", ), DtypeSpec( - name='foo', - dtype=RefSpec('Foo', 'object'), - doc='The foo in this table.' 
+ name="foo", + dtype=RefSpec("Foo", "object"), + doc="The foo in this table.", ), ], - name='MyBaz', + name="MyBaz", shape=[None], - data_type_def='Baz', - attributes=[AttributeSpec('baz_attr', 'an example string attribute', 'text')] + data_type_def="Baz", + attributes=[AttributeSpec("baz_attr", "an example string attribute", "text")], ) def test_build(self): - ''' Test default mapping functionality when no attributes are nested ''' - foo = Foo('my_foo1', [1, 2, 3], 'string', 10) - baz = Baz('MyBaz', [(1, foo)], 'abcdefghijklmnopqrstuvwxyz') + """Test default mapping functionality when no attributes are nested""" + foo = Foo("my_foo1", [1, 2, 3], "string", 10) + baz = Baz("MyBaz", [(1, foo)], "abcdefghijklmnopqrstuvwxyz") foo_builder = self.manager.build(foo) baz_builder = self.manager.build(baz, root=True) - expected = DatasetBuilder('MyBaz', [(1, ReferenceBuilder(foo_builder))], - attributes={'baz_attr': 'abcdefghijklmnopqrstuvwxyz', - 'data_type': 'Baz', - 'namespace': CORE_NAMESPACE, - 'object_id': baz.object_id}) + expected = DatasetBuilder( + "MyBaz", + [(1, ReferenceBuilder(foo_builder))], + attributes={ + "baz_attr": "abcdefghijklmnopqrstuvwxyz", + "data_type": "Baz", + "namespace": CORE_NAMESPACE, + "object_id": baz.object_id, + }, + ) self.assertBuilderEqual(baz_builder, expected) class TestBuildTypedDatasetOfReferences(BuildDatasetOfReferencesMixin, TestCase): - def setUpBazSpec(self): self.baz_spec = DatasetSpec( - doc='a list of references to Foo objects', - dtype=RefSpec('Foo', 'object'), - name='MyBaz', + doc="a list of references to Foo objects", + dtype=RefSpec("Foo", "object"), + name="MyBaz", shape=[None], - data_type_def='Baz', - attributes=[AttributeSpec('baz_attr', 'an example string attribute', 'text')] + data_type_def="Baz", + attributes=[AttributeSpec("baz_attr", "an example string attribute", "text")], ) def test_build(self): - ''' Test default mapping functionality when no attributes are nested ''' - foo = Foo('my_foo1', [1, 2, 3], 'string', 10) - baz = Baz('MyBaz', [foo], 'abcdefghijklmnopqrstuvwxyz') + """Test default mapping functionality when no attributes are nested""" + foo = Foo("my_foo1", [1, 2, 3], "string", 10) + baz = Baz("MyBaz", [foo], "abcdefghijklmnopqrstuvwxyz") foo_builder = self.manager.build(foo) baz_builder = self.manager.build(baz, root=True) - expected = DatasetBuilder('MyBaz', [ReferenceBuilder(foo_builder)], - attributes={'baz_attr': 'abcdefghijklmnopqrstuvwxyz', - 'data_type': 'Baz', - 'namespace': CORE_NAMESPACE, - 'object_id': baz.object_id}) + expected = DatasetBuilder( + "MyBaz", + [ReferenceBuilder(foo_builder)], + attributes={ + "baz_attr": "abcdefghijklmnopqrstuvwxyz", + "data_type": "Baz", + "namespace": CORE_NAMESPACE, + "object_id": baz.object_id, + }, + ) self.assertBuilderEqual(baz_builder, expected) class TestBuildDatasetOfReferencesUnbuiltTarget(BuildDatasetOfReferencesMixin, TestCase): - def setUpBazSpec(self): self.baz_spec = DatasetSpec( - doc='a list of references to Foo objects', + doc="a list of references to Foo objects", dtype=None, - name='MyBaz', + name="MyBaz", shape=[None], - data_type_def='Baz', - attributes=[AttributeSpec('baz_attr', 'an example string attribute', 'text')] + data_type_def="Baz", + attributes=[AttributeSpec("baz_attr", "an example string attribute", "text")], ) def test_build(self): - ''' Test default mapping functionality when no attributes are nested ''' - foo = Foo('my_foo1', [1, 2, 3], 'string', 10) - baz = Baz('MyBaz', [foo], 'abcdefghijklmnopqrstuvwxyz') + """Test default mapping 
functionality when no attributes are nested""" + foo = Foo("my_foo1", [1, 2, 3], "string", 10) + baz = Baz("MyBaz", [foo], "abcdefghijklmnopqrstuvwxyz") msg = "MyBaz (MyBaz): Could not find already-built Builder for Foo 'my_foo1' in BuildManager" with self.assertRaisesWith(ReferenceTargetNotBuiltError, msg): self.manager.build(baz, root=True) class TestDataIOEdgeCases(TestCase): - def setUp(self): self.setUpBazSpec() self.spec_catalog = SpecCatalog() - self.spec_catalog.register_spec(self.baz_spec, 'test.yaml') - self.namespace = SpecNamespace('a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], - version='0.1.0', - catalog=self.spec_catalog) + self.spec_catalog.register_spec(self.baz_spec, "test.yaml") + self.namespace = SpecNamespace( + "a test namespace", + CORE_NAMESPACE, + [{"source": "test.yaml"}], + version="0.1.0", + catalog=self.spec_catalog, + ) self.namespace_catalog = NamespaceCatalog() self.namespace_catalog.add_namespace(CORE_NAMESPACE, self.namespace) self.type_map = TypeMap(self.namespace_catalog) - self.type_map.register_container_type(CORE_NAMESPACE, 'Baz', Baz) + self.type_map.register_container_type(CORE_NAMESPACE, "Baz", Baz) self.type_map.register_map(Baz, ObjectMapper) self.manager = BuildManager(self.type_map) self.mapper = ObjectMapper(self.baz_spec) def setUpBazSpec(self): self.baz_spec = DatasetSpec( - doc='an Baz type', + doc="an Baz type", dtype=None, - name='MyBaz', - data_type_def='Baz', + name="MyBaz", + data_type_def="Baz", shape=[None], - attributes=[AttributeSpec('baz_attr', 'an example string attribute', 'text')] + attributes=[AttributeSpec("baz_attr", "an example string attribute", "text")], ) def test_build_dataio(self): """Test building of a dataset with data_type and no dtype with value DataIO.""" - container = Baz('my_baz', H5DataIO(['a', 'b', 'c', 'd'], chunks=True), 'value1') + container = Baz("my_baz", H5DataIO(["a", "b", "c", "d"], chunks=True), "value1") builder = self.type_map.build(container) - self.assertIsInstance(builder.get('data'), H5DataIO) + self.assertIsInstance(builder.get("data"), H5DataIO) def test_build_datachunkiterator(self): """Test building of a dataset with data_type and no dtype with value DataChunkIterator.""" - container = Baz('my_baz', DataChunkIterator(['a', 'b', 'c', 'd']), 'value1') + container = Baz("my_baz", DataChunkIterator(["a", "b", "c", "d"]), "value1") builder = self.type_map.build(container) - self.assertIsInstance(builder.get('data'), DataChunkIterator) + self.assertIsInstance(builder.get("data"), DataChunkIterator) def test_build_dataio_datachunkiterator(self): # hdmf#512 """Test building of a dataset with no dtype and no data_type with value DataIO wrapping a DCI.""" - container = Baz('my_baz', H5DataIO(DataChunkIterator(['a', 'b', 'c', 'd']), chunks=True), 'value1') + container = Baz( + "my_baz", + H5DataIO(DataChunkIterator(["a", "b", "c", "d"]), chunks=True), + "value1", + ) builder = self.type_map.build(container) - self.assertIsInstance(builder.get('data'), H5DataIO) - self.assertIsInstance(builder.get('data').data, DataChunkIterator) + self.assertIsInstance(builder.get("data"), H5DataIO) + self.assertIsInstance(builder.get("data").data, DataChunkIterator) diff --git a/tests/unit/common/test_alignedtable.py b/tests/unit/common/test_alignedtable.py index f334aff27..f8706c941 100644 --- a/tests/unit/common/test_alignedtable.py +++ b/tests/unit/common/test_alignedtable.py @@ -1,9 +1,16 @@ +import warnings + import numpy as np from pandas.testing import assert_frame_equal -import warnings from 
hdmf.backends.hdf5 import HDF5IO -from hdmf.common import DynamicTable, VectorData, get_manager, AlignedDynamicTable, DynamicTableRegion +from hdmf.common import ( + AlignedDynamicTable, + DynamicTable, + DynamicTableRegion, + VectorData, + get_manager, +) from hdmf.testing import TestCase, remove_test_file @@ -17,220 +24,345 @@ class TestAlignedDynamicTableContainer(TestCase): * get_linked_tables methods are tested in the test_linkedtables.TestLinkedAlignedDynamicTables class instead of here. """ + def setUp(self): warnings.simplefilter("always") # Trigger all warnings - self.path = 'test_icephys_meta_intracellularrecording.h5' + self.path = "test_icephys_meta_intracellularrecording.h5" def tearDown(self): remove_test_file(self.path) def test_init(self): """Test that just checks that populating the tables with data works correctly""" - AlignedDynamicTable( - name='test_aligned_table', - description='Test aligned container') + AlignedDynamicTable(name="test_aligned_table", description="Test aligned container") def test_init_categories_without_category_tables_error(self): # Test raise error if categories is given without category_tables with self.assertRaisesWith(ValueError, "Categories provided but no category_tables given"): AlignedDynamicTable( - name='test_aligned_table', - description='Test aligned container', - categories=['cat1', 'cat2']) + name="test_aligned_table", + description="Test aligned container", + categories=["cat1", "cat2"], + ) def test_init_length_mismatch_between_categories_and_category_tables(self): # Test length mismatch between categories and category_tables - with self.assertRaisesWith(ValueError, "0 category_tables given but 2 categories specified"): + with self.assertRaisesWith(ValueError, "0 category_tables given but 2 categories specified"): AlignedDynamicTable( - name='test_aligned_table', - description='Test aligned container', - categories=['cat1', 'cat2'], - category_tables=[]) + name="test_aligned_table", + description="Test aligned container", + categories=["cat1", "cat2"], + category_tables=[], + ) def test_init_category_table_names_do_not_match_categories(self): # Construct some categories for testing - category_names = ['test1', 'test2', 'test3'] + category_names = ["test1", "test2", "test3"] num_rows = 10 - categories = [DynamicTable(name=val, - description=val+" description", - columns=[VectorData(name=val+t, - description=val+t+' description', - data=np.arange(num_rows)) for t in ['c1', 'c2', 'c3']] - ) for val in category_names] + categories = [ + DynamicTable( + name=val, + description=val + " description", + columns=[ + VectorData( + name=val + t, + description=val + t + " description", + data=np.arange(num_rows), + ) + for t in ["c1", "c2", "c3"] + ], + ) + for val in category_names + ] # Test add category_table that is not listed in the categories list - with self.assertRaisesWith(ValueError, - "DynamicTable test3 does not appear in categories ['test1', 'test2', 't3']"): + with self.assertRaisesWith( + ValueError, + "DynamicTable test3 does not appear in categories ['test1', 'test2', 't3']", + ): AlignedDynamicTable( - name='test_aligned_table', - description='Test aligned container', - categories=['test1', 'test2', 't3'], # bad name for 'test3' - category_tables=categories) + name="test_aligned_table", + description="Test aligned container", + categories=["test1", "test2", "t3"], # bad name for 'test3' + category_tables=categories, + ) def test_init_duplicate_category_table_name(self): # Test duplicate table name - with 
self.assertRaisesWith(ValueError, "Duplicate table name test1 found in input dynamic_tables"): - categories = [DynamicTable(name=val, - description=val+" description", - columns=[VectorData(name=val+t, - description=val+t+' description', - data=np.arange(10)) for t in ['c1', 'c2', 'c3']] - ) for val in ['test1', 'test1', 'test3']] + with self.assertRaisesWith( + ValueError, + "Duplicate table name test1 found in input dynamic_tables", + ): + categories = [ + DynamicTable( + name=val, + description=val + " description", + columns=[ + VectorData( + name=val + t, + description=val + t + " description", + data=np.arange(10), + ) + for t in ["c1", "c2", "c3"] + ], + ) + for val in ["test1", "test1", "test3"] + ] AlignedDynamicTable( - name='test_aligned_table', - description='Test aligned container', - categories=['test1', 'test2', 'test3'], - category_tables=categories) + name="test_aligned_table", + description="Test aligned container", + categories=["test1", "test2", "test3"], + category_tables=categories, + ) def test_init_misaligned_category_tables(self): """Test misaligned category tables""" - categories = [DynamicTable(name=val, - description=val+" description", - columns=[VectorData(name=val+t, - description=val+t+' description', - data=np.arange(10)) for t in ['c1', 'c2', 'c3']] - ) for val in ['test1', 'test2']] - categories.append(DynamicTable(name='test3', - description="test3 description", - columns=[VectorData(name='test3 '+t, - description='test3 '+t+' description', - data=np.arange(8)) for t in ['c1', 'c2', 'c3']])) - with self.assertRaisesWith(ValueError, - "Category DynamicTable test3 does not align, it has 8 rows expected 10"): + categories = [ + DynamicTable( + name=val, + description=val + " description", + columns=[ + VectorData( + name=val + t, + description=val + t + " description", + data=np.arange(10), + ) + for t in ["c1", "c2", "c3"] + ], + ) + for val in ["test1", "test2"] + ] + categories.append( + DynamicTable( + name="test3", + description="test3 description", + columns=[ + VectorData( + name="test3 " + t, + description="test3 " + t + " description", + data=np.arange(8), + ) + for t in ["c1", "c2", "c3"] + ], + ) + ) + with self.assertRaisesWith( + ValueError, + "Category DynamicTable test3 does not align, it has 8 rows expected 10", + ): AlignedDynamicTable( - name='test_aligned_table', - description='Test aligned container', - categories=['test1', 'test2', 'test3'], - category_tables=categories) + name="test_aligned_table", + description="Test aligned container", + categories=["test1", "test2", "test3"], + category_tables=categories, + ) def test_init_with_custom_empty_categories(self): """Test that we can create an empty table with custom categories""" - category_names = ['test1', 'test2', 'test3'] - categories = [DynamicTable(name=val, description=val+" description") for val in category_names] + category_names = ["test1", "test2", "test3"] + categories = [DynamicTable(name=val, description=val + " description") for val in category_names] AlignedDynamicTable( - name='test_aligned_table', - description='Test aligned container', - category_tables=categories) + name="test_aligned_table", + description="Test aligned container", + category_tables=categories, + ) def test_init_with_custom_nonempty_categories(self): """Test that we can create an empty table with custom categories""" - category_names = ['test1', 'test2', 'test3'] + category_names = ["test1", "test2", "test3"] num_rows = 10 - categories = [DynamicTable(name=val, - description=val+" description", - 
columns=[VectorData(name=val+t, - description=val+t+' description', - data=np.arange(num_rows)) for t in ['c1', 'c2', 'c3']] - ) for val in category_names] + categories = [ + DynamicTable( + name=val, + description=val + " description", + columns=[ + VectorData( + name=val + t, + description=val + t + " description", + data=np.arange(num_rows), + ) + for t in ["c1", "c2", "c3"] + ], + ) + for val in category_names + ] temp = AlignedDynamicTable( - name='test_aligned_table', - description='Test aligned container', - category_tables=categories) + name="test_aligned_table", + description="Test aligned container", + category_tables=categories, + ) self.assertEqual(temp.categories, category_names) def test_init_with_custom_nonempty_categories_and_main(self): """ Test that we can create a non-empty table with custom non-empty categories """ - category_names = ['test1', 'test2', 'test3'] + category_names = ["test1", "test2", "test3"] num_rows = 10 - categories = [DynamicTable(name=val, - description=val+" description", - columns=[VectorData(name=t, - description=val+t+' description', - data=np.arange(num_rows)) for t in ['c1', 'c2', 'c3']] - ) for val in category_names] + categories = [ + DynamicTable( + name=val, + description=val + " description", + columns=[ + VectorData( + name=t, + description=val + t + " description", + data=np.arange(num_rows), + ) + for t in ["c1", "c2", "c3"] + ], + ) + for val in category_names + ] temp = AlignedDynamicTable( - name='test_aligned_table', - description='Test aligned container', + name="test_aligned_table", + description="Test aligned container", category_tables=categories, - columns=[VectorData(name='main_' + t, - description='main_'+t+'_description', - data=np.arange(num_rows)) for t in ['c1', 'c2', 'c3']]) + columns=[ + VectorData( + name="main_" + t, + description="main_" + t + "_description", + data=np.arange(num_rows), + ) + for t in ["c1", "c2", "c3"] + ], + ) self.assertEqual(temp.categories, category_names) - self.assertTrue('test1' in temp) # test that contains category works - self.assertTrue(('test1', 'c1') in temp) # test that contains a column works + self.assertTrue("test1" in temp) # test that contains category works + self.assertTrue(("test1", "c1") in temp) # test that contains a column works # test the error case of a tuple with len !=2 - with self.assertRaisesWith(ValueError, "Expected tuple of strings of length 2 got tuple of length 3"): - ('test1', 'c1', 't3') in temp - self.assertTupleEqual(temp.colnames, ('main_c1', 'main_c2', 'main_c3')) # confirm column names + with self.assertRaisesWith( + ValueError, + "Expected tuple of strings of length 2 got tuple of length 3", + ): + ("test1", "c1", "t3") in temp + self.assertTupleEqual(temp.colnames, ("main_c1", "main_c2", "main_c3")) # confirm column names def test_init_with_custom_misaligned_categories(self): """Test that we cannot create an empty table with custom categories""" num_rows = 10 - val1 = 'test1' - val2 = 'test2' - categories = [DynamicTable(name=val1, - description=val1+" description", - columns=[VectorData(name=val1+t, - description=val1+t+' description', - data=np.arange(num_rows)) for t in ['c1', 'c2', 'c3']]), - DynamicTable(name=val2, - description=val2+" description", - columns=[VectorData(name=val2+t, - description=val2+t+' description', - data=np.arange(num_rows+1)) for t in ['c1', 'c2', 'c3']]) - ] - with self.assertRaisesWith(ValueError, - "Category DynamicTable test2 does not align, it has 11 rows expected 10"): + val1 = "test1" + val2 = "test2" + categories = [ 
+ DynamicTable( + name=val1, + description=val1 + " description", + columns=[ + VectorData( + name=val1 + t, + description=val1 + t + " description", + data=np.arange(num_rows), + ) + for t in ["c1", "c2", "c3"] + ], + ), + DynamicTable( + name=val2, + description=val2 + " description", + columns=[ + VectorData( + name=val2 + t, + description=val2 + t + " description", + data=np.arange(num_rows + 1), + ) + for t in ["c1", "c2", "c3"] + ], + ), + ] + with self.assertRaisesWith( + ValueError, + "Category DynamicTable test2 does not align, it has 11 rows expected 10", + ): AlignedDynamicTable( - name='test_aligned_table', - description='Test aligned container', - category_tables=categories) + name="test_aligned_table", + description="Test aligned container", + category_tables=categories, + ) def test_init_with_duplicate_custom_categories(self): """Test that we can create an empty table with custom categories""" - category_names = ['test1', 'test1'] + category_names = ["test1", "test1"] num_rows = 10 - categories = [DynamicTable(name=val, - description=val+" description", - columns=[VectorData(name=val+t, - description=val+t+' description', - data=np.arange(num_rows)) for t in ['c1', 'c2', 'c3']] - ) for val in category_names] - with self.assertRaisesWith(ValueError, "Duplicate table name test1 found in input dynamic_tables"): + categories = [ + DynamicTable( + name=val, + description=val + " description", + columns=[ + VectorData( + name=val + t, + description=val + t + " description", + data=np.arange(num_rows), + ) + for t in ["c1", "c2", "c3"] + ], + ) + for val in category_names + ] + with self.assertRaisesWith( + ValueError, + "Duplicate table name test1 found in input dynamic_tables", + ): AlignedDynamicTable( - name='test_aligned_table', - description='Test aligned container', - category_tables=categories) + name="test_aligned_table", + description="Test aligned container", + category_tables=categories, + ) def test_init_with_bad_custom_categories(self): """Test that we cannot provide a category that is not a DynamicTable""" num_rows = 10 categories = [ # good category - DynamicTable(name='test1', - description="test1 description", - columns=[VectorData(name='test1'+t, - description='test1' + t + ' description', - data=np.arange(num_rows)) for t in ['c1', 'c2', 'c3']] - ), - # use a list as a bad category example - [0, 1, 2]] + DynamicTable( + name="test1", + description="test1 description", + columns=[ + VectorData( + name="test1" + t, + description="test1" + t + " description", + data=np.arange(num_rows), + ) + for t in ["c1", "c2", "c3"] + ], + ), + # use a list as a bad category example + [0, 1, 2], + ] with self.assertRaisesWith(ValueError, "Category table with index 1 is not a DynamicTable"): AlignedDynamicTable( - name='test_aligned_table', - description='Test aligned container', - category_tables=categories) + name="test_aligned_table", + description="Test aligned container", + category_tables=categories, + ) def test_round_trip_container(self): """Test read and write the container by itself""" - category_names = ['test1', 'test2', 'test3'] + category_names = ["test1", "test2", "test3"] num_rows = 10 - categories = [DynamicTable(name=val, - description=val+" description", - columns=[VectorData(name=t, - description=val+t+' description', - data=np.arange(num_rows)) for t in ['c1', 'c2', 'c3']] - ) for val in category_names] + categories = [ + DynamicTable( + name=val, + description=val + " description", + columns=[ + VectorData( + name=t, + description=val + t + " description", 
+ data=np.arange(num_rows), + ) + for t in ["c1", "c2", "c3"] + ], + ) + for val in category_names + ] curr = AlignedDynamicTable( - name='test_aligned_table', - description='Test aligned container', - category_tables=categories) + name="test_aligned_table", + description="Test aligned container", + category_tables=categories, + ) - with HDF5IO(self.path, manager=get_manager(), mode='w') as io: + with HDF5IO(self.path, manager=get_manager(), mode="w") as io: io.write(curr) - with HDF5IO(self.path, manager=get_manager(), mode='r') as io: + with HDF5IO(self.path, manager=get_manager(), mode="r") as io: incon = io.read() self.assertListEqual(incon.categories, curr.categories) for n in category_names: @@ -238,132 +370,231 @@ def test_round_trip_container(self): def test_add_category(self): """Test that we can correct a non-empty category to an existing table""" - category_names = ['test1', 'test2', 'test3'] + category_names = ["test1", "test2", "test3"] num_rows = 10 - categories = [DynamicTable(name=val, - description=val+" description", - columns=[VectorData(name=val+t, - description=val+t+' description', - data=np.arange(num_rows)) for t in ['c1', 'c2', 'c3']] - ) for val in category_names] + categories = [ + DynamicTable( + name=val, + description=val + " description", + columns=[ + VectorData( + name=val + t, + description=val + t + " description", + data=np.arange(num_rows), + ) + for t in ["c1", "c2", "c3"] + ], + ) + for val in category_names + ] adt = AlignedDynamicTable( - name='test_aligned_table', - description='Test aligned container', - category_tables=categories[0:2]) + name="test_aligned_table", + description="Test aligned container", + category_tables=categories[0:2], + ) self.assertListEqual(adt.categories, category_names[0:2]) adt.add_category(categories[-1]) self.assertListEqual(adt.categories, category_names) def test_add_category_misaligned_rows(self): """Test that we can correct a non-empty category to an existing table""" - category_names = ['test1', 'test2'] + category_names = ["test1", "test2"] num_rows = 10 - categories = [DynamicTable(name=val, - description=val+" description", - columns=[VectorData(name=val+t, - description=val+t+' description', - data=np.arange(num_rows)) for t in ['c1', 'c2', 'c3']] - ) for val in category_names] + categories = [ + DynamicTable( + name=val, + description=val + " description", + columns=[ + VectorData( + name=val + t, + description=val + t + " description", + data=np.arange(num_rows), + ) + for t in ["c1", "c2", "c3"] + ], + ) + for val in category_names + ] adt = AlignedDynamicTable( - name='test_aligned_table', - description='Test aligned container', - category_tables=categories) + name="test_aligned_table", + description="Test aligned container", + category_tables=categories, + ) self.assertListEqual(adt.categories, category_names) - with self.assertRaisesWith(ValueError, "New category DynamicTable does not align, it has 8 rows expected 10"): - adt.add_category(DynamicTable(name='test3', - description='test3_description', - columns=[VectorData(name='test3_'+t, - description='test3 '+t+' description', - data=np.arange(num_rows - 2)) for t in ['c1', 'c2', 'c3'] - ])) + with self.assertRaisesWith( + ValueError, + "New category DynamicTable does not align, it has 8 rows expected 10", + ): + adt.add_category( + DynamicTable( + name="test3", + description="test3_description", + columns=[ + VectorData( + name="test3_" + t, + description="test3 " + t + " description", + data=np.arange(num_rows - 2), + ) + for t in ["c1", "c2", "c3"] + 
], + ) + ) def test_add_category_already_in_table(self): - category_names = ['test1', 'test2', 'test2'] + category_names = ["test1", "test2", "test2"] num_rows = 10 - categories = [DynamicTable(name=val, - description=val+" description", - columns=[VectorData(name=val+t, - description=val+t+' description', - data=np.arange(num_rows)) for t in ['c1', 'c2', 'c3']] - ) for val in category_names] + categories = [ + DynamicTable( + name=val, + description=val + " description", + columns=[ + VectorData( + name=val + t, + description=val + t + " description", + data=np.arange(num_rows), + ) + for t in ["c1", "c2", "c3"] + ], + ) + for val in category_names + ] adt = AlignedDynamicTable( - name='test_aligned_table', - description='Test aligned container', - category_tables=categories[0:2]) + name="test_aligned_table", + description="Test aligned container", + category_tables=categories[0:2], + ) self.assertListEqual(adt.categories, category_names[0:2]) with self.assertRaisesWith(ValueError, "Category test2 already in the table"): adt.add_category(categories[-1]) def test_add_column(self): adt = AlignedDynamicTable( - name='test_aligned_table', - description='Test aligned container', - columns=[VectorData(name='test_'+t, - description='test_'+t+' description', - data=np.arange(10)) for t in ['c1', 'c2', 'c3']]) + name="test_aligned_table", + description="Test aligned container", + columns=[ + VectorData( + name="test_" + t, + description="test_" + t + " description", + data=np.arange(10), + ) + for t in ["c1", "c2", "c3"] + ], + ) # Test successful add - adt.add_column(name='testA', description='testA', data=np.arange(10)) - self.assertTupleEqual(adt.colnames, ('test_c1', 'test_c2', 'test_c3', 'testA')) + adt.add_column(name="testA", description="testA", data=np.arange(10)) + self.assertTupleEqual(adt.colnames, ("test_c1", "test_c2", "test_c3", "testA")) def test_add_column_bad_category(self): """Test add column with bad category""" adt = AlignedDynamicTable( - name='test_aligned_table', - description='Test aligned container', - columns=[VectorData(name='test_'+t, - description='test_'+t+' description', - data=np.arange(10)) for t in ['c1', 'c2', 'c3']]) + name="test_aligned_table", + description="Test aligned container", + columns=[ + VectorData( + name="test_" + t, + description="test_" + t + " description", + data=np.arange(10), + ) + for t in ["c1", "c2", "c3"] + ], + ) with self.assertRaisesWith(KeyError, "'Category mycat not in table'"): - adt.add_column(category='mycat', name='testA', description='testA', data=np.arange(10)) + adt.add_column( + category="mycat", + name="testA", + description="testA", + data=np.arange(10), + ) def test_add_column_bad_length(self): """Test add column that is too short""" adt = AlignedDynamicTable( - name='test_aligned_table', - description='Test aligned container', - columns=[VectorData(name='test_'+t, - description='test_'+t+' description', - data=np.arange(10)) for t in ['c1', 'c2', 'c3']]) + name="test_aligned_table", + description="Test aligned container", + columns=[ + VectorData( + name="test_" + t, + description="test_" + t + " description", + data=np.arange(10), + ) + for t in ["c1", "c2", "c3"] + ], + ) # Test successful add with self.assertRaisesWith(ValueError, "column must have the same number of rows as 'id'"): - adt.add_column(name='testA', description='testA', data=np.arange(8)) + adt.add_column(name="testA", description="testA", data=np.arange(8)) def test_add_column_to_subcategory(self): """Test adding a column to a subcategory""" - 
category_names = ['test1', 'test2', 'test3'] + category_names = ["test1", "test2", "test3"] num_rows = 10 - categories = [DynamicTable(name=val, - description=val+" description", - columns=[VectorData(name=val+t, - description=val+t+' description', - data=np.arange(num_rows)) for t in ['c1', 'c2', 'c3']] - ) for val in category_names] + categories = [ + DynamicTable( + name=val, + description=val + " description", + columns=[ + VectorData( + name=val + t, + description=val + t + " description", + data=np.arange(num_rows), + ) + for t in ["c1", "c2", "c3"] + ], + ) + for val in category_names + ] adt = AlignedDynamicTable( - name='test_aligned_table', - description='Test aligned container', - category_tables=categories) + name="test_aligned_table", + description="Test aligned container", + category_tables=categories, + ) self.assertListEqual(adt.categories, category_names) # Test successful add - adt.add_column(category='test2', name='testA', description='testA', data=np.arange(10)) - self.assertTupleEqual(adt.get_category('test2').colnames, ('test2c1', 'test2c2', 'test2c3', 'testA')) + adt.add_column( + category="test2", + name="testA", + description="testA", + data=np.arange(10), + ) + self.assertTupleEqual( + adt.get_category("test2").colnames, + ("test2c1", "test2c2", "test2c3", "testA"), + ) def test_add_row(self): """Test adding a row to a non_empty table""" - category_names = ['test1', ] + category_names = [ + "test1", + ] num_rows = 10 - categories = [DynamicTable(name=val, - description=val+" description", - columns=[VectorData(name=t, - description=val+t+' description', - data=np.arange(num_rows)) for t in ['c1', 'c2']] - ) for val in category_names] + categories = [ + DynamicTable( + name=val, + description=val + " description", + columns=[ + VectorData( + name=t, + description=val + t + " description", + data=np.arange(num_rows), + ) + for t in ["c1", "c2"] + ], + ) + for val in category_names + ] temp = AlignedDynamicTable( - name='test_aligned_table', - description='Test aligned container', + name="test_aligned_table", + description="Test aligned container", category_tables=categories, - columns=[VectorData(name='main_' + t, - description='main_'+t+'_description', - data=np.arange(num_rows)) for t in ['c1', 'c2']]) + columns=[ + VectorData( + name="main_" + t, + description="main_" + t + "_description", + data=np.arange(num_rows), + ) + for t in ["c1", "c2"] + ], + ) self.assertListEqual(temp.categories, category_names) # Test successful add temp.add_row(test1=dict(c1=1, c2=2), main_c1=3, main_c2=5) @@ -378,22 +609,38 @@ def test_add_row(self): def test_get_item(self): """Test getting elements from the table""" - category_names = ['test1', ] + category_names = [ + "test1", + ] num_rows = 10 - categories = [DynamicTable(name=val, - description=val+" description", - columns=[VectorData(name=t, - description=val+t+' description', - data=np.arange(num_rows) + i + 3) - for i, t in enumerate(['c1', 'c2'])] - ) for val in category_names] + categories = [ + DynamicTable( + name=val, + description=val + " description", + columns=[ + VectorData( + name=t, + description=val + t + " description", + data=np.arange(num_rows) + i + 3, + ) + for i, t in enumerate(["c1", "c2"]) + ], + ) + for val in category_names + ] temp = AlignedDynamicTable( - name='test_aligned_table', - description='Test aligned container', + name="test_aligned_table", + description="Test aligned container", category_tables=categories, - columns=[VectorData(name='main_' + t, - description='main_'+t+'_description', - 
data=np.arange(num_rows)+2) for t in ['c1', 'c2']]) + columns=[ + VectorData( + name="main_" + t, + description="main_" + t + "_description", + data=np.arange(num_rows) + 2, + ) + for t in ["c1", "c2"] + ], + ) self.assertListEqual(temp.categories, category_names) # Test slicing with a single index self.assertListEqual(temp[5].iloc[0].tolist(), [7, 7, 5, 8, 9]) @@ -407,60 +654,90 @@ def test_get_item(self): self.assertListEqual(temp[np.asarray([5, 8])].iloc[0].tolist(), [7, 7, 5, 8, 9]) self.assertListEqual(temp[np.asarray([5, 8])].iloc[1].tolist(), [10, 10, 8, 11, 12]) # Test slicing for a single column - self.assertListEqual(temp['main_c1'][:].tolist(), (np.arange(num_rows)+2).tolist()) + self.assertListEqual(temp["main_c1"][:].tolist(), (np.arange(num_rows) + 2).tolist()) # Test slicing for a single category - assert_frame_equal(temp['test1'], categories[0].to_dataframe()) + assert_frame_equal(temp["test1"], categories[0].to_dataframe()) # Test getting the main table assert_frame_equal(temp[None], temp.to_dataframe()) # Test getting a specific column - self.assertListEqual(temp['test1', 'c1'][:].tolist(), (np.arange(num_rows) + 3).tolist()) + self.assertListEqual(temp["test1", "c1"][:].tolist(), (np.arange(num_rows) + 3).tolist()) # Test getting a specific cell - self.assertEqual(temp[None, 'main_c1', 1], 3) - self.assertEqual(temp[1, None, 'main_c1'], 3) + self.assertEqual(temp[None, "main_c1", 1], 3) + self.assertEqual(temp[1, None, "main_c1"], 3) # Test bad selection tuple - with self.assertRaisesWith(ValueError, - "Expected tuple of length 2 of the form [category, column], [row, category], " - "[row, (category, column)] or a tuple of length 3 of the form " - "[category, column, row], [row, category, column]"): - temp[('main_c1',)] + with self.assertRaisesWith( + ValueError, + ( + "Expected tuple of length 2 of the form [category, column], [row," + " category], [row, (category, column)] or a tuple of length 3 of the" + " form [category, column, row], [row, category, column]" + ), + ): + temp[("main_c1",)] # Test selecting a single cell or row of a category table by having a # [int, str] or [int, (str, str)] type selection # Select row 0 from category 'test1' - re = temp[0, 'test1'] - self.assertListEqual(re.columns.to_list(), ['id', 'c1', 'c2']) - self.assertListEqual(re.index.names, [('test_aligned_table', 'id')]) + re = temp[0, "test1"] + self.assertListEqual(re.columns.to_list(), ["id", "c1", "c2"]) + self.assertListEqual(re.index.names, [("test_aligned_table", "id")]) self.assertListEqual(re.values.tolist()[0], [0, 3, 4]) # Select a single cell from a column - self.assertEqual(temp[1, ('test_aligned_table', 'main_c1')], 3) + self.assertEqual(temp[1, ("test_aligned_table", "main_c1")], 3) def test_to_dataframe(self): """Test that the to_dataframe method works""" - category_names = ['test1', 'test2', 'test3'] + category_names = ["test1", "test2", "test3"] num_rows = 10 - categories = [DynamicTable(name=val, - description=val+" description", - columns=[VectorData(name=t, - description=val+t+' description', - data=np.arange(num_rows)) for t in ['c1', 'c2', 'c3']] - ) for val in category_names] + categories = [ + DynamicTable( + name=val, + description=val + " description", + columns=[ + VectorData( + name=t, + description=val + t + " description", + data=np.arange(num_rows), + ) + for t in ["c1", "c2", "c3"] + ], + ) + for val in category_names + ] adt = AlignedDynamicTable( - name='test_aligned_table', - description='Test aligned container', + name="test_aligned_table", + 
description="Test aligned container", category_tables=categories, - columns=[VectorData(name='main_' + t, - description='main_'+t+'_description', - data=np.arange(num_rows)) for t in ['c1', 'c2', 'c3']]) + columns=[ + VectorData( + name="main_" + t, + description="main_" + t + "_description", + data=np.arange(num_rows), + ) + for t in ["c1", "c2", "c3"] + ], + ) # Test the to_dataframe method with default settings tdf = adt.to_dataframe() self.assertListEqual(tdf.index.tolist(), list(range(10))) - self.assertTupleEqual(tdf.index.name, ('test_aligned_table', 'id')) - expected_cols = [('test_aligned_table', 'main_c1'), - ('test_aligned_table', 'main_c2'), - ('test_aligned_table', 'main_c3'), - ('test1', 'id'), ('test1', 'c1'), ('test1', 'c2'), ('test1', 'c3'), - ('test2', 'id'), ('test2', 'c1'), ('test2', 'c2'), ('test2', 'c3'), - ('test3', 'id'), ('test3', 'c1'), ('test3', 'c2'), ('test3', 'c3')] + self.assertTupleEqual(tdf.index.name, ("test_aligned_table", "id")) + expected_cols = [ + ("test_aligned_table", "main_c1"), + ("test_aligned_table", "main_c2"), + ("test_aligned_table", "main_c3"), + ("test1", "id"), + ("test1", "c1"), + ("test1", "c2"), + ("test1", "c3"), + ("test2", "id"), + ("test2", "c1"), + ("test2", "c2"), + ("test2", "c3"), + ("test3", "id"), + ("test3", "c1"), + ("test3", "c2"), + ("test3", "c3"), + ] tdf_cols = tdf.columns.tolist() for v in zip(expected_cols, tdf_cols): self.assertTupleEqual(v[0], v[1]) @@ -468,13 +745,21 @@ def test_to_dataframe(self): # test the to_dataframe method with ignore_category_ids set to True tdf = adt.to_dataframe(ignore_category_ids=True) self.assertListEqual(tdf.index.tolist(), list(range(10))) - self.assertTupleEqual(tdf.index.name, ('test_aligned_table', 'id')) - expected_cols = [('test_aligned_table', 'main_c1'), - ('test_aligned_table', 'main_c2'), - ('test_aligned_table', 'main_c3'), - ('test1', 'c1'), ('test1', 'c2'), ('test1', 'c3'), - ('test2', 'c1'), ('test2', 'c2'), ('test2', 'c3'), - ('test3', 'c1'), ('test3', 'c2'), ('test3', 'c3')] + self.assertTupleEqual(tdf.index.name, ("test_aligned_table", "id")) + expected_cols = [ + ("test_aligned_table", "main_c1"), + ("test_aligned_table", "main_c2"), + ("test_aligned_table", "main_c3"), + ("test1", "c1"), + ("test1", "c2"), + ("test1", "c3"), + ("test2", "c1"), + ("test2", "c2"), + ("test2", "c3"), + ("test3", "c1"), + ("test3", "c2"), + ("test3", "c3"), + ] tdf_cols = tdf.columns.tolist() for v in zip(expected_cols, tdf_cols): self.assertTupleEqual(v[0], v[1]) @@ -484,34 +769,67 @@ def test_nested_aligned_dynamic_table_not_allowed(self): Test that using and AlignedDynamicTable as category for an AlignedDynamicTable is not allowed """ # create an AlignedDynamicTable as category - subsubcol1 = VectorData(name='sub_sub_column1', description='test sub sub column', data=['test11', 'test12']) - sub_category = DynamicTable(name='sub_category1', description='test subcategory table', columns=[subsubcol1, ]) - subcol1 = VectorData(name='sub_column1', description='test-subcolumn', data=['test1', 'test2']) + subsubcol1 = VectorData( + name="sub_sub_column1", + description="test sub sub column", + data=["test11", "test12"], + ) + sub_category = DynamicTable( + name="sub_category1", + description="test subcategory table", + columns=[ + subsubcol1, + ], + ) + subcol1 = VectorData( + name="sub_column1", + description="test-subcolumn", + data=["test1", "test2"], + ) adt_category = AlignedDynamicTable( - name='category1', - description='test using AlignedDynamicTable as a category', - 
columns=[subcol1, ], - category_tables=[sub_category, ]) + name="category1", + description="test using AlignedDynamicTable as a category", + columns=[ + subcol1, + ], + category_tables=[ + sub_category, + ], + ) # Create a regular column for our main AlignedDynamicTable - col1 = VectorData(name='column1', description='regular test column', data=['test1', 'test2']) + col1 = VectorData( + name="column1", + description="regular test column", + data=["test1", "test2"], + ) # test 1: Make sure we can't add the AlignedDynamicTable category on init - msg = ("Category table with index %i is an AlignedDynamicTable. " - "Nesting of AlignedDynamicTable is currently not supported." % 0) + msg = ( + "Category table with index %i is an AlignedDynamicTable. " + "Nesting of AlignedDynamicTable is currently not supported." % 0 + ) with self.assertRaisesWith(ValueError, msg): # create the nested AlignedDynamicTable with our adt_category as a sub-category AlignedDynamicTable( - name='nested_adt', - description='test nesting AlignedDynamicTable', - columns=[col1, ], - category_tables=[adt_category, ]) + name="nested_adt", + description="test nesting AlignedDynamicTable", + columns=[ + col1, + ], + category_tables=[ + adt_category, + ], + ) # test 2: Make sure we can't add the AlignedDynamicTable category via add_category adt = AlignedDynamicTable( - name='nested_adt', - description='test nesting AlignedDynamicTable', - columns=[col1, ]) + name="nested_adt", + description="test nesting AlignedDynamicTable", + columns=[ + col1, + ], + ) msg = "Category is an AlignedDynamicTable. Nesting of AlignedDynamicTable is currently not supported." with self.assertRaisesWith(ValueError, msg): adt.add_category(adt_category) @@ -522,25 +840,46 @@ def test_dynamictable_region_to_aligneddynamictable(self): In particular, make sure that all columns are being used, including those of the category tables, not just the ones from the main table. 
""" - temp_table = DynamicTable(name='t1', description='t1', - colnames=['c1', 'c2'], - columns=[VectorData(name='c1', description='c1', data=np.arange(4)), - VectorData(name='c2', description='c2', data=np.arange(4))]) - temp_aligned_table = AlignedDynamicTable(name='my_aligned_table', - description='my test table', - category_tables=[temp_table], - colnames=['a1', 'a2'], - columns=[VectorData(name='a1', description='c1', data=np.arange(4)), - VectorData(name='a2', description='c1', data=np.arange(4))]) - dtr = DynamicTableRegion(name='test', description='test', data=np.arange(4), table=temp_aligned_table) + temp_table = DynamicTable( + name="t1", + description="t1", + colnames=["c1", "c2"], + columns=[ + VectorData(name="c1", description="c1", data=np.arange(4)), + VectorData(name="c2", description="c2", data=np.arange(4)), + ], + ) + temp_aligned_table = AlignedDynamicTable( + name="my_aligned_table", + description="my test table", + category_tables=[temp_table], + colnames=["a1", "a2"], + columns=[ + VectorData(name="a1", description="c1", data=np.arange(4)), + VectorData(name="a2", description="c1", data=np.arange(4)), + ], + ) + dtr = DynamicTableRegion( + name="test", + description="test", + data=np.arange(4), + table=temp_aligned_table, + ) dtr_df = dtr[:] # Full number of rows self.assertEqual(len(dtr_df), 4) # Test num columns: 2 columns from the main table, 2 columns from the category, 1 id columns from the category self.assertEqual(len(dtr_df.columns), 5) # Test that the data is correct - for i, v in enumerate([('my_aligned_table', 'a1'), ('my_aligned_table', 'a2'), - ('t1', 'id'), ('t1', 'c1'), ('t1', 'c2')]): + for i, v in enumerate( + [ + ("my_aligned_table", "a1"), + ("my_aligned_table", "a2"), + ("t1", "id"), + ("t1", "c1"), + ("t1", "c2"), + ] + ): self.assertTupleEqual(dtr_df.columns[i], v) # Test the column data for c in dtr_df.columns: @@ -550,39 +889,82 @@ def test_get_colnames(self): """ Test the AlignedDynamicTable.get_colnames function """ - category_names = ['test1', 'test2', 'test3'] + category_names = ["test1", "test2", "test3"] num_rows = 10 - categories = [DynamicTable(name=val, - description=val+" description", - columns=[VectorData(name=t, - description=val+t+' description', - data=np.arange(num_rows)) for t in ['c1', 'c2', 'c3']] - ) for val in category_names] + categories = [ + DynamicTable( + name=val, + description=val + " description", + columns=[ + VectorData( + name=t, + description=val + t + " description", + data=np.arange(num_rows), + ) + for t in ["c1", "c2", "c3"] + ], + ) + for val in category_names + ] adt = AlignedDynamicTable( - name='test_aligned_table', - description='Test aligned container', + name="test_aligned_table", + description="Test aligned container", category_tables=categories, - columns=[VectorData(name='main_' + t, - description='main_'+t+'_description', - data=np.arange(num_rows)) for t in ['c1', 'c2', 'c3']]) + columns=[ + VectorData( + name="main_" + t, + description="main_" + t + "_description", + data=np.arange(num_rows), + ) + for t in ["c1", "c2", "c3"] + ], + ) # Default, only get the colnames of the main table. 
Same as adt.colnames property - expected_colnames = ('main_c1', 'main_c2', 'main_c3') + expected_colnames = ("main_c1", "main_c2", "main_c3") self.assertTupleEqual(adt.get_colnames(), expected_colnames) # Same as default because if we don't include the categories than ignore_category_ids has no effect - self.assertTupleEqual(adt.get_colnames(include_category_tables=False, ignore_category_ids=True), - expected_colnames) + self.assertTupleEqual( + adt.get_colnames(include_category_tables=False, ignore_category_ids=True), + expected_colnames, + ) # Full set of columns - expected_colnames = [('test_aligned_table', 'main_c1'), ('test_aligned_table', 'main_c2'), - ('test_aligned_table', 'main_c3'), ('test1', 'id'), ('test1', 'c1'), - ('test1', 'c2'), ('test1', 'c3'), ('test2', 'id'), ('test2', 'c1'), - ('test2', 'c2'), ('test2', 'c3'), ('test3', 'id'), ('test3', 'c1'), - ('test3', 'c2'), ('test3', 'c3')] - self.assertListEqual(adt.get_colnames(include_category_tables=True, ignore_category_ids=False), - expected_colnames) + expected_colnames = [ + ("test_aligned_table", "main_c1"), + ("test_aligned_table", "main_c2"), + ("test_aligned_table", "main_c3"), + ("test1", "id"), + ("test1", "c1"), + ("test1", "c2"), + ("test1", "c3"), + ("test2", "id"), + ("test2", "c1"), + ("test2", "c2"), + ("test2", "c3"), + ("test3", "id"), + ("test3", "c1"), + ("test3", "c2"), + ("test3", "c3"), + ] + self.assertListEqual( + adt.get_colnames(include_category_tables=True, ignore_category_ids=False), + expected_colnames, + ) # All columns without the id columns of the category tables - expected_colnames = [('test_aligned_table', 'main_c1'), ('test_aligned_table', 'main_c2'), - ('test_aligned_table', 'main_c3'), ('test1', 'c1'), ('test1', 'c2'), - ('test1', 'c3'), ('test2', 'c1'), ('test2', 'c2'), ('test2', 'c3'), - ('test3', 'c1'), ('test3', 'c2'), ('test3', 'c3')] - self.assertListEqual(adt.get_colnames(include_category_tables=True, ignore_category_ids=True), - expected_colnames) + expected_colnames = [ + ("test_aligned_table", "main_c1"), + ("test_aligned_table", "main_c2"), + ("test_aligned_table", "main_c3"), + ("test1", "c1"), + ("test1", "c2"), + ("test1", "c3"), + ("test2", "c1"), + ("test2", "c2"), + ("test2", "c3"), + ("test3", "c1"), + ("test3", "c2"), + ("test3", "c3"), + ] + self.assertListEqual( + adt.get_colnames(include_category_tables=True, ignore_category_ids=True), + expected_colnames, + ) diff --git a/tests/unit/common/test_common.py b/tests/unit/common/test_common.py index 76c99d44a..6d7df5d16 100644 --- a/tests/unit/common/test_common.py +++ b/tests/unit/common/test_common.py @@ -1,13 +1,12 @@ -from hdmf import Data, Container +from hdmf import Container, Data from hdmf.common import get_type_map from hdmf.testing import TestCase class TestCommonTypeMap(TestCase): - def test_base_types(self): tm = get_type_map() - cls = tm.get_dt_container_cls('Container', 'hdmf-common') + cls = tm.get_dt_container_cls("Container", "hdmf-common") self.assertIs(cls, Container) - cls = tm.get_dt_container_cls('Data', 'hdmf-common') + cls = tm.get_dt_container_cls("Data", "hdmf-common") self.assertIs(cls, Data) diff --git a/tests/unit/common/test_common_io.py b/tests/unit/common/test_common_io.py index a3324040e..3c6a5f775 100644 --- a/tests/unit/common/test_common_io.py +++ b/tests/unit/common/test_common_io.py @@ -1,11 +1,11 @@ from h5py import File from hdmf.backends.hdf5 import HDF5IO -from hdmf.common import Container, get_manager, get_hdf5io +from hdmf.common import Container, get_hdf5io, get_manager from 
hdmf.spec import NamespaceCatalog from hdmf.testing import TestCase, remove_test_file -from tests.unit.helpers.utils import get_temp_filepath +from ..helpers.utils import get_temp_filepath class TestCacheSpec(TestCase): @@ -17,7 +17,7 @@ class TestCacheSpec(TestCase): def setUp(self): self.manager = get_manager() self.path = get_temp_filepath() - self.container = Container('dummy') + self.container = Container("dummy") def tearDown(self): remove_test_file(self.path) @@ -26,24 +26,24 @@ def test_write_no_cache_spec(self): """Roundtrip test for not writing spec.""" with HDF5IO(self.path, manager=self.manager, mode="a") as io: io.write(self.container, cache_spec=False) - with File(self.path, 'r') as f: - self.assertNotIn('specifications', f) + with File(self.path, "r") as f: + self.assertNotIn("specifications", f) def test_write_cache_spec(self): """Roundtrip test for writing spec and reading it back in.""" with HDF5IO(self.path, manager=self.manager, mode="a") as io: io.write(self.container) - with File(self.path, 'r') as f: - self.assertIn('specifications', f) + with File(self.path, "r") as f: + self.assertIn("specifications", f) self._check_spec() def test_write_cache_spec_injected(self): """Roundtrip test for writing spec and reading it back in when HDF5IO is passed an open h5py.File.""" - with File(self.path, 'w') as fil: - with HDF5IO(self.path, manager=self.manager, file=fil, mode='a') as io: + with File(self.path, "w") as fil: + with HDF5IO(self.path, manager=self.manager, file=fil, mode="a") as io: io.write(self.container) - with File(self.path, 'r') as f: - self.assertIn('specifications', f) + with File(self.path, "r") as f: + self.assertIn("specifications", f) self._check_spec() def _check_spec(self): @@ -55,7 +55,7 @@ def _check_spec(self): original_ns = self.manager.namespace_catalog.get_namespace(namespace) cached_ns = ns_catalog.get_namespace(namespace) ns_fields_to_check = list(original_ns.keys()) - ns_fields_to_check.remove('schema') # schema fields will not match, so reset + ns_fields_to_check.remove("schema") # schema fields will not match, so reset for ns_field in ns_fields_to_check: with self.subTest(namespace_field=ns_field): self.assertEqual(original_ns[ns_field], cached_ns[ns_field]) @@ -63,14 +63,13 @@ def _check_spec(self): with self.subTest(data_type=dt): original_spec = original_ns.get_spec(dt) cached_spec = cached_ns.get_spec(dt) - with self.subTest('Data type spec is read back in'): + with self.subTest("Data type spec is read back in"): self.assertIsNotNone(cached_spec) - with self.subTest('Cached spec matches original spec'): + with self.subTest("Cached spec matches original spec"): self.assertDictEqual(original_spec, cached_spec) class TestGetHdf5IO(TestCase): - def setUp(self): self.path = get_temp_filepath() diff --git a/tests/unit/common/test_generate_table.py b/tests/unit/common/test_generate_table.py index 8d76e651d..23e4d47bc 100644 --- a/tests/unit/common/test_generate_table.py +++ b/tests/unit/common/test_generate_table.py @@ -1,129 +1,137 @@ -import numpy as np import os import shutil import tempfile +import numpy as np + from hdmf.backends.hdf5 import HDF5IO from hdmf.build import BuildManager, TypeMap -from hdmf.common import get_type_map, DynamicTable, VectorData -from hdmf.spec import GroupSpec, DatasetSpec, SpecCatalog, SpecNamespace, NamespaceCatalog +from hdmf.common import DynamicTable, VectorData, get_type_map +from hdmf.spec import ( + DatasetSpec, + GroupSpec, + NamespaceCatalog, + SpecCatalog, + SpecNamespace, +) from hdmf.testing import 
TestCase from hdmf.validate import ValidatorMap -from tests.unit.helpers.utils import CORE_NAMESPACE +from ..helpers.utils import CORE_NAMESPACE class TestDynamicDynamicTable(TestCase): - def setUp(self): self.dt_spec = GroupSpec( - 'A test extension that contains a dynamic table', - data_type_def='TestTable', - data_type_inc='DynamicTable', + "A test extension that contains a dynamic table", + data_type_def="TestTable", + data_type_inc="DynamicTable", datasets=[ DatasetSpec( - data_type_inc='VectorData', - name='my_col', - doc='a test column', - dtype='float' + data_type_inc="VectorData", + name="my_col", + doc="a test column", + dtype="float", ), DatasetSpec( - data_type_inc='VectorData', - name='indexed_col', - doc='a test column', - dtype='float' + data_type_inc="VectorData", + name="indexed_col", + doc="a test column", + dtype="float", ), DatasetSpec( - data_type_inc='VectorIndex', - name='indexed_col_index', - doc='a test column', + data_type_inc="VectorIndex", + name="indexed_col_index", + doc="a test column", ), DatasetSpec( - data_type_inc='VectorData', - name='optional_col1', - doc='a test column', - dtype='float', - quantity='?', + data_type_inc="VectorData", + name="optional_col1", + doc="a test column", + dtype="float", + quantity="?", ), DatasetSpec( - data_type_inc='VectorData', - name='optional_col2', - doc='a test column', - dtype='float', - quantity='?', - ) - ] + data_type_inc="VectorData", + name="optional_col2", + doc="a test column", + dtype="float", + quantity="?", + ), + ], ) self.dt_spec2 = GroupSpec( - 'A test extension that contains a dynamic table', - data_type_def='TestDTRTable', - data_type_inc='DynamicTable', + "A test extension that contains a dynamic table", + data_type_def="TestDTRTable", + data_type_inc="DynamicTable", datasets=[ DatasetSpec( - data_type_inc='DynamicTableRegion', - name='ref_col', - doc='a test column', + data_type_inc="DynamicTableRegion", + name="ref_col", + doc="a test column", ), DatasetSpec( - data_type_inc='DynamicTableRegion', - name='indexed_ref_col', - doc='a test column', + data_type_inc="DynamicTableRegion", + name="indexed_ref_col", + doc="a test column", ), DatasetSpec( - data_type_inc='VectorIndex', - name='indexed_ref_col_index', - doc='a test column', + data_type_inc="VectorIndex", + name="indexed_ref_col_index", + doc="a test column", ), DatasetSpec( - data_type_inc='DynamicTableRegion', - name='optional_ref_col', - doc='a test column', - quantity='?' + data_type_inc="DynamicTableRegion", + name="optional_ref_col", + doc="a test column", + quantity="?", ), DatasetSpec( - data_type_inc='DynamicTableRegion', - name='optional_indexed_ref_col', - doc='a test column', - quantity='?' + data_type_inc="DynamicTableRegion", + name="optional_indexed_ref_col", + doc="a test column", + quantity="?", ), DatasetSpec( - data_type_inc='VectorIndex', - name='optional_indexed_ref_col_index', - doc='a test column', - quantity='?' 
+ data_type_inc="VectorIndex", + name="optional_indexed_ref_col_index", + doc="a test column", + quantity="?", ), DatasetSpec( - data_type_inc='VectorData', - name='optional_col3', - doc='a test column', - dtype='float', - quantity='?', - ) - ] + data_type_inc="VectorData", + name="optional_col3", + doc="a test column", + dtype="float", + quantity="?", + ), + ], ) from hdmf.spec.write import YAMLSpecWriter - writer = YAMLSpecWriter(outdir='.') + + writer = YAMLSpecWriter(outdir=".") self.spec_catalog = SpecCatalog() - self.spec_catalog.register_spec(self.dt_spec, 'test.yaml') - self.spec_catalog.register_spec(self.dt_spec2, 'test.yaml') + self.spec_catalog.register_spec(self.dt_spec, "test.yaml") + self.spec_catalog.register_spec(self.dt_spec2, "test.yaml") self.namespace = SpecNamespace( - 'a test namespace', CORE_NAMESPACE, + "a test namespace", + CORE_NAMESPACE, [ dict( - namespace='hdmf-common', + namespace="hdmf-common", ), - dict(source='test.yaml'), + dict(source="test.yaml"), ], - version='0.1.0', - catalog=self.spec_catalog + version="0.1.0", + catalog=self.spec_catalog, ) self.test_dir = tempfile.mkdtemp() - spec_fpath = os.path.join(self.test_dir, 'test.yaml') - namespace_fpath = os.path.join(self.test_dir, 'test-namespace.yaml') + spec_fpath = os.path.join(self.test_dir, "test.yaml") + namespace_fpath = os.path.join(self.test_dir, "test-namespace.yaml") writer.write_spec(dict(groups=[self.dt_spec, self.dt_spec2]), spec_fpath) writer.write_namespace(self.namespace, namespace_fpath) self.namespace_catalog = NamespaceCatalog() @@ -133,8 +141,8 @@ def setUp(self): self.type_map.load_namespaces(namespace_fpath) self.manager = BuildManager(self.type_map) - self.TestTable = self.type_map.get_dt_container_cls('TestTable', CORE_NAMESPACE) - self.TestDTRTable = self.type_map.get_dt_container_cls('TestDTRTable', CORE_NAMESPACE) + self.TestTable = self.type_map.get_dt_container_cls("TestTable", CORE_NAMESPACE) + self.TestDTRTable = self.type_map.get_dt_container_cls("TestDTRTable", CORE_NAMESPACE) def tearDown(self) -> None: shutil.rmtree(self.test_dir) @@ -143,104 +151,138 @@ def test_dynamic_table(self): assert issubclass(self.TestTable, DynamicTable) assert self.TestTable.__columns__[0] == { - 'name': 'my_col', - 'description': 'a test column', - 'class': VectorData, - 'required': True - } + "name": "my_col", + "description": "a test column", + "class": VectorData, + "required": True, + } def test_forbids_incorrect_col(self): - test_table = self.TestTable(name='test_table', description='my test table') + test_table = self.TestTable(name="test_table", description="my test table") with self.assertRaises(ValueError): test_table.add_row(my_col=3.0, indexed_col=[1.0, 3.0], incorrect_col=5) def test_dynamic_column(self): - test_table = self.TestTable(name='test_table', description='my test table') - test_table.add_column('dynamic_column', 'this is a dynamic column') + test_table = self.TestTable(name="test_table", description="my test table") + test_table.add_column("dynamic_column", "this is a dynamic column") test_table.add_row( - my_col=3.0, indexed_col=[1.0, 3.0], dynamic_column=4, optional_col2=.5, + my_col=3.0, + indexed_col=[1.0, 3.0], + dynamic_column=4, + optional_col2=0.5, ) test_table.add_row( - my_col=4.0, indexed_col=[2.0, 4.0], dynamic_column=4, optional_col2=.5, + my_col=4.0, + indexed_col=[2.0, 4.0], + dynamic_column=4, + optional_col2=0.5, ) - np.testing.assert_array_equal(test_table['indexed_col'].target.data, [1., 3., 2., 4.]) - 
np.testing.assert_array_equal(test_table['dynamic_column'].data, [4, 4]) + np.testing.assert_array_equal(test_table["indexed_col"].target.data, [1.0, 3.0, 2.0, 4.0]) + np.testing.assert_array_equal(test_table["dynamic_column"].data, [4, 4]) def test_optional_col(self): - test_table = self.TestTable(name='test_table', description='my test table') - test_table.add_row(my_col=3.0, indexed_col=[1.0, 3.0], optional_col2=.5) - test_table.add_row(my_col=4.0, indexed_col=[2.0, 4.0], optional_col2=.5) + test_table = self.TestTable(name="test_table", description="my test table") + test_table.add_row(my_col=3.0, indexed_col=[1.0, 3.0], optional_col2=0.5) + test_table.add_row(my_col=4.0, indexed_col=[2.0, 4.0], optional_col2=0.5) def test_dynamic_table_region(self): - test_table = self.TestTable(name='test_table', description='my test table') - test_table.add_row(my_col=3.0, indexed_col=[1.0, 3.0], optional_col2=.5) - test_table.add_row(my_col=4.0, indexed_col=[2.0, 4.0], optional_col2=.5) - - test_dtr_table = self.TestDTRTable(name='test_dtr_table', description='my table', - target_tables={'ref_col': test_table, - 'indexed_ref_col': test_table}) - self.assertIs(test_dtr_table['ref_col'].table, test_table) - self.assertIs(test_dtr_table['indexed_ref_col'].target.table, test_table) + test_table = self.TestTable(name="test_table", description="my test table") + test_table.add_row(my_col=3.0, indexed_col=[1.0, 3.0], optional_col2=0.5) + test_table.add_row(my_col=4.0, indexed_col=[2.0, 4.0], optional_col2=0.5) + + test_dtr_table = self.TestDTRTable( + name="test_dtr_table", + description="my table", + target_tables={ + "ref_col": test_table, + "indexed_ref_col": test_table, + }, + ) + self.assertIs(test_dtr_table["ref_col"].table, test_table) + self.assertIs(test_dtr_table["indexed_ref_col"].target.table, test_table) test_dtr_table.add_row(ref_col=0, indexed_ref_col=[0, 1]) test_dtr_table.add_row(ref_col=0, indexed_ref_col=[0, 1]) - np.testing.assert_array_equal(test_dtr_table['indexed_ref_col'].target.data, [0, 1, 0, 1]) - np.testing.assert_array_equal(test_dtr_table['ref_col'].data, [0, 0]) + np.testing.assert_array_equal(test_dtr_table["indexed_ref_col"].target.data, [0, 1, 0, 1]) + np.testing.assert_array_equal(test_dtr_table["ref_col"].data, [0, 0]) def test_dynamic_table_region_optional(self): - test_table = self.TestTable(name='test_table', description='my test table') - test_table.add_row(my_col=3.0, indexed_col=[1.0, 3.0], optional_col2=.5) - test_table.add_row(my_col=4.0, indexed_col=[2.0, 4.0], optional_col2=.5) - - test_dtr_table = self.TestDTRTable(name='test_dtr_table', description='my table', - target_tables={'optional_ref_col': test_table, - 'optional_indexed_ref_col': test_table}) - self.assertIs(test_dtr_table['optional_ref_col'].table, test_table) - self.assertIs(test_dtr_table['optional_indexed_ref_col'].target.table, test_table) - - test_dtr_table.add_row(ref_col=0, indexed_ref_col=[0, 1], - optional_ref_col=0, optional_indexed_ref_col=[0, 1]) - test_dtr_table.add_row(ref_col=0, indexed_ref_col=[0, 1], - optional_ref_col=0, optional_indexed_ref_col=[0, 1]) + test_table = self.TestTable(name="test_table", description="my test table") + test_table.add_row(my_col=3.0, indexed_col=[1.0, 3.0], optional_col2=0.5) + test_table.add_row(my_col=4.0, indexed_col=[2.0, 4.0], optional_col2=0.5) + + test_dtr_table = self.TestDTRTable( + name="test_dtr_table", + description="my table", + target_tables={ + "optional_ref_col": test_table, + "optional_indexed_ref_col": test_table, + }, + ) + 
self.assertIs(test_dtr_table["optional_ref_col"].table, test_table) + self.assertIs(test_dtr_table["optional_indexed_ref_col"].target.table, test_table) + + test_dtr_table.add_row( + ref_col=0, + indexed_ref_col=[0, 1], + optional_ref_col=0, + optional_indexed_ref_col=[0, 1], + ) + test_dtr_table.add_row( + ref_col=0, + indexed_ref_col=[0, 1], + optional_ref_col=0, + optional_indexed_ref_col=[0, 1], + ) - np.testing.assert_array_equal(test_dtr_table['optional_indexed_ref_col'].target.data, [0, 1, 0, 1]) - np.testing.assert_array_equal(test_dtr_table['optional_ref_col'].data, [0, 0]) + np.testing.assert_array_equal(test_dtr_table["optional_indexed_ref_col"].target.data, [0, 1, 0, 1]) + np.testing.assert_array_equal(test_dtr_table["optional_ref_col"].data, [0, 0]) def test_dynamic_table_region_bad_target_col(self): - test_table = self.TestTable(name='test_table', description='my test table') - test_table.add_row(my_col=3.0, indexed_col=[1.0, 3.0], optional_col2=.5) - test_table.add_row(my_col=4.0, indexed_col=[2.0, 4.0], optional_col2=.5) + test_table = self.TestTable(name="test_table", description="my test table") + test_table.add_row(my_col=3.0, indexed_col=[1.0, 3.0], optional_col2=0.5) + test_table.add_row(my_col=4.0, indexed_col=[2.0, 4.0], optional_col2=0.5) msg = r"^'bad' is not the name of a predefined column of table .*" with self.assertRaisesRegex(ValueError, msg): - self.TestDTRTable(name='test_dtr_table', description='my table', target_tables={'bad': test_table}) + self.TestDTRTable( + name="test_dtr_table", + description="my table", + target_tables={"bad": test_table}, + ) def test_dynamic_table_region_non_dtr_target(self): - test_table = self.TestTable(name='test_table', description='my test table') - test_table.add_row(my_col=3.0, indexed_col=[1.0, 3.0], optional_col2=.5) - test_table.add_row(my_col=4.0, indexed_col=[2.0, 4.0], optional_col2=.5) + test_table = self.TestTable(name="test_table", description="my test table") + test_table.add_row(my_col=3.0, indexed_col=[1.0, 3.0], optional_col2=0.5) + test_table.add_row(my_col=4.0, indexed_col=[2.0, 4.0], optional_col2=0.5) msg = "Column 'optional_col3' must be a DynamicTableRegion to have a target table." 
with self.assertRaisesWith(ValueError, msg): - self.TestDTRTable(name='test_dtr_table', description='my table', - target_tables={'optional_col3': test_table}) + self.TestDTRTable( + name="test_dtr_table", + description="my table", + target_tables={"optional_col3": test_table}, + ) def test_roundtrip(self): # NOTE this does not use H5RoundTripMixin because this requires custom validation - test_table = self.TestTable(name='test_table', description='my test table') - test_table.add_column('dynamic_column', 'this is a dynamic column') + test_table = self.TestTable(name="test_table", description="my test table") + test_table.add_column("dynamic_column", "this is a dynamic column") test_table.add_row( - my_col=3.0, indexed_col=[1.0, 3.0], dynamic_column=4, optional_col2=.5, + my_col=3.0, + indexed_col=[1.0, 3.0], + dynamic_column=4, + optional_col2=0.5, ) - self.filename = os.path.join(self.test_dir, 'test_TestTable.h5') + self.filename = os.path.join(self.test_dir, "test_TestTable.h5") - with HDF5IO(self.filename, manager=self.manager, mode='w') as write_io: + with HDF5IO(self.filename, manager=self.manager, mode="w") as write_io: write_io.write(test_table, cache_spec=True) - self.reader = HDF5IO(self.filename, manager=self.manager, mode='r') + self.reader = HDF5IO(self.filename, manager=self.manager, mode="r") read_container = self.reader.read() self.assertIsNotNone(str(test_table)) # added as a test to make sure printing works diff --git a/tests/unit/common/test_linkedtables.py b/tests/unit/common/test_linkedtables.py index 25a80efa1..1f0b205b6 100644 --- a/tests/unit/common/test_linkedtables.py +++ b/tests/unit/common/test_linkedtables.py @@ -3,72 +3,109 @@ """ import numpy as np -from hdmf.common import DynamicTable, AlignedDynamicTable, VectorData, DynamicTableRegion, VectorIndex -from hdmf.testing import TestCase -from hdmf.utils import docval, popargs, get_docval -from hdmf.common.hierarchicaltable import to_hierarchical_dataframe, drop_id_columns, flatten_column_index from pandas.testing import assert_frame_equal +from hdmf.common import ( + AlignedDynamicTable, + DynamicTable, + DynamicTableRegion, + VectorData, + VectorIndex, +) +from hdmf.common.hierarchicaltable import ( + drop_id_columns, + flatten_column_index, + to_hierarchical_dataframe, +) +from hdmf.testing import TestCase +from hdmf.utils import docval, get_docval, popargs + class DynamicTableSingleDTR(DynamicTable): """Test table class that references a single foreign table""" + __columns__ = ( - {'name': 'child_table_ref1', - 'description': 'Column with a references to the next level in the hierarchy', - 'required': True, - 'index': True, - 'table': True}, + { + "name": "child_table_ref1", + "description": "Column with a references to the next level in the hierarchy", + "required": True, + "index": True, + "table": True, + }, ) - @docval({'name': 'name', 'type': str, 'doc': 'The name of the table'}, - {'name': 'child_table1', - 'type': DynamicTable, - 'doc': 'the child DynamicTable this DynamicTableSingleDTR point to.'}, - *get_docval(DynamicTable.__init__, 'id', 'columns', 'colnames')) + @docval( + { + "name": "name", + "type": str, + "doc": "The name of the table", + }, + { + "name": "child_table1", + "type": DynamicTable, + "doc": "the child DynamicTable this DynamicTableSingleDTR point to.", + }, + *get_docval(DynamicTable.__init__, "id", "columns", "colnames"), + ) def __init__(self, **kwargs): # Define default name and description settings - kwargs['description'] = (kwargs['name'] + " DynamicTableSingleDTR") - 
child_table1 = popargs('child_table1', kwargs) + kwargs["description"] = kwargs["name"] + " DynamicTableSingleDTR" + child_table1 = popargs("child_table1", kwargs) # Initialize the DynamicTable super().__init__(**kwargs) - if self['child_table_ref1'].target.table is None: - self['child_table_ref1'].target.table = child_table1 + if self["child_table_ref1"].target.table is None: + self["child_table_ref1"].target.table = child_table1 class DynamicTableMultiDTR(DynamicTable): """Test table class that references multiple related tables""" + __columns__ = ( - {'name': 'child_table_ref1', - 'description': 'Column with a references to the next level in the hierarchy', - 'required': True, - 'index': True, - 'table': True}, - {'name': 'child_table_ref2', - 'description': 'Column with a references to the next level in the hierarchy', - 'required': True, - 'index': True, - 'table': True}, + { + "name": "child_table_ref1", + "description": "Column with a references to the next level in the hierarchy", + "required": True, + "index": True, + "table": True, + }, + { + "name": "child_table_ref2", + "description": "Column with a references to the next level in the hierarchy", + "required": True, + "index": True, + "table": True, + }, ) - @docval({'name': 'name', 'type': str, 'doc': 'The name of the table'}, - {'name': 'child_table1', - 'type': DynamicTable, - 'doc': 'the child DynamicTable this DynamicTableSingleDTR point to.'}, - {'name': 'child_table2', - 'type': DynamicTable, - 'doc': 'the child DynamicTable this DynamicTableSingleDTR point to.'}, - *get_docval(DynamicTable.__init__, 'id', 'columns', 'colnames')) + @docval( + { + "name": "name", + "type": str, + "doc": "The name of the table", + }, + { + "name": "child_table1", + "type": DynamicTable, + "doc": "the child DynamicTable this DynamicTableSingleDTR point to.", + }, + { + "name": "child_table2", + "type": DynamicTable, + "doc": "the child DynamicTable this DynamicTableSingleDTR point to.", + }, + *get_docval(DynamicTable.__init__, "id", "columns", "colnames"), + ) def __init__(self, **kwargs): # Define default name and description settings - kwargs['description'] = (kwargs['name'] + " DynamicTableSingleDTR") - child_table1 = popargs('child_table1', kwargs) - child_table2 = popargs('child_table2', kwargs) + kwargs["description"] = kwargs["name"] + " DynamicTableSingleDTR" + child_table1 = popargs("child_table1", kwargs) + child_table2 = popargs("child_table2", kwargs) # Initialize the DynamicTable super().__init__(**kwargs) - if self['child_table_ref1'].target.table is None: - self['child_table_ref1'].target.table = child_table1 - if self['child_table_ref2'].target.table is None: - self['child_table_ref2'].target.table = child_table2 + if self["child_table_ref1"].target.table is None: + self["child_table_ref1"].target.table = child_table1 + if self["child_table_ref2"].target.table is None: + self["child_table_ref2"].target.table = child_table2 class TestLinkedAlignedDynamicTables(TestCase): @@ -80,6 +117,7 @@ class TestLinkedAlignedDynamicTables(TestCase): we test with container class. The only time I/O becomes relevant is on read in case that, e.g., a h5py.Dataset may behave differently than a numpy array. """ + def setUp(self): """ Create basic set of linked tables consisting of @@ -91,59 +129,87 @@ def setUp(self): +--> category1 ---> table_level_0_1 """ # Level 0 0 table. 
I.e., first table on level 0 - self.table_level0_0 = DynamicTable(name='level0_0', description="level0_0 DynamicTable") + self.table_level0_0 = DynamicTable(name="level0_0", description="level0_0 DynamicTable") self.table_level0_0.add_row(id=10) self.table_level0_0.add_row(id=11) self.table_level0_0.add_row(id=12) self.table_level0_0.add_row(id=13) - self.table_level0_0.add_column(data=['tag1', 'tag2', 'tag2', 'tag1', 'tag3', 'tag4', 'tag5'], - name='tags', - description='custom tags', - index=[1, 2, 4, 7]) - self.table_level0_0.add_column(data=np.arange(4), - name='myid', - description='custom ids', - index=False) + self.table_level0_0.add_column( + data=["tag1", "tag2", "tag2", "tag1", "tag3", "tag4", "tag5"], + name="tags", + description="custom tags", + index=[1, 2, 4, 7], + ) + self.table_level0_0.add_column( + data=np.arange(4), + name="myid", + description="custom ids", + index=False, + ) # Level 0 1 table. I.e., second table on level 0 - self.table_level0_1 = DynamicTable(name='level0_1', description="level0_1 DynamicTable") + self.table_level0_1 = DynamicTable(name="level0_1", description="level0_1 DynamicTable") self.table_level0_1.add_row(id=14) self.table_level0_1.add_row(id=15) self.table_level0_1.add_row(id=16) self.table_level0_1.add_row(id=17) - self.table_level0_1.add_column(data=['tag1', 'tag1', 'tag2', 'tag2', 'tag3', 'tag3', 'tag4'], - name='tags', - description='custom tags', - index=[2, 4, 6, 7]) - self.table_level0_1.add_column(data=np.arange(4), - name='myid', - description='custom ids', - index=False) + self.table_level0_1.add_column( + data=["tag1", "tag1", "tag2", "tag2", "tag3", "tag3", "tag4"], + name="tags", + description="custom tags", + index=[2, 4, 6, 7], + ) + self.table_level0_1.add_column( + data=np.arange(4), + name="myid", + description="custom ids", + index=False, + ) # category 0 table - self.category0 = DynamicTableSingleDTR(name='category0', child_table1=self.table_level0_0) - self.category0.add_row(id=0, child_table_ref1=[0, ]) + self.category0 = DynamicTableSingleDTR(name="category0", child_table1=self.table_level0_0) + self.category0.add_row( + id=0, + child_table_ref1=[ + 0, + ], + ) self.category0.add_row(id=1, child_table_ref1=[1, 2]) - self.category0.add_row(id=1, child_table_ref1=[3, ]) - self.category0.add_column(data=[10, 11, 12], - name='filter', - description='filter value', - index=False) + self.category0.add_row( + id=1, + child_table_ref1=[ + 3, + ], + ) + self.category0.add_column( + data=[10, 11, 12], + name="filter", + description="filter value", + index=False, + ) # category 1 table - self.category1 = DynamicTableSingleDTR(name='category1', child_table1=self.table_level0_1) + self.category1 = DynamicTableSingleDTR(name="category1", child_table1=self.table_level0_1) self.category1.add_row(id=0, child_table_ref1=[0, 1]) self.category1.add_row(id=1, child_table_ref1=[2, 3]) self.category1.add_row(id=1, child_table_ref1=[1, 3]) - self.category1.add_column(data=[1, 2, 3], - name='filter', - description='filter value', - index=False) + self.category1.add_column( + data=[1, 2, 3], + name="filter", + description="filter value", + index=False, + ) # Aligned table - self.aligned_table = AlignedDynamicTable(name='my_aligned_table', - description='my test table', - columns=[VectorData(name='a1', description='a1', data=np.arange(3)), ], - colnames=['a1', ], - category_tables=[self.category0, self.category1]) + self.aligned_table = AlignedDynamicTable( + name="my_aligned_table", + description="my test table", + columns=[ + 
VectorData(name="a1", description="a1", data=np.arange(3)), + ], + colnames=[ + "a1", + ], + category_tables=[self.category0, self.category1], + ) def tearDown(self): del self.table_level0_0 @@ -155,21 +221,35 @@ def tearDown(self): def test_to_hierarchical_dataframe(self): """Test that converting an AlignedDynamicTable with links works""" hier_df = to_hierarchical_dataframe(self.aligned_table) - self.assertListEqual(hier_df.columns.to_list(), - [('level0_0', 'id'), ('level0_0', 'tags'), ('level0_0', 'myid')]) - self.assertListEqual(hier_df.index.names, - [('my_aligned_table', 'id'), ('my_aligned_table', ('my_aligned_table', 'a1')), - ('my_aligned_table', ('category0', 'id')), ('my_aligned_table', ('category0', 'filter')), - ('my_aligned_table', ('category1', 'id')), - ('my_aligned_table', ('category1', 'child_table_ref1')), - ('my_aligned_table', ('category1', 'filter'))]) - self.assertListEqual(hier_df.index.to_list(), - [(0, 0, 0, 10, 0, (0, 1), 1), - (1, 1, 1, 11, 1, (2, 3), 2), - (1, 1, 1, 11, 1, (2, 3), 2), - (2, 2, 1, 12, 1, (1, 3), 3)]) - self.assertListEqual(hier_df[('level0_0', 'tags')].values.tolist(), - [['tag1'], ['tag2'], ['tag2', 'tag1'], ['tag3', 'tag4', 'tag5']]) + self.assertListEqual( + hier_df.columns.to_list(), + [("level0_0", "id"), ("level0_0", "tags"), ("level0_0", "myid")], + ) + self.assertListEqual( + hier_df.index.names, + [ + ("my_aligned_table", "id"), + ("my_aligned_table", ("my_aligned_table", "a1")), + ("my_aligned_table", ("category0", "id")), + ("my_aligned_table", ("category0", "filter")), + ("my_aligned_table", ("category1", "id")), + ("my_aligned_table", ("category1", "child_table_ref1")), + ("my_aligned_table", ("category1", "filter")), + ], + ) + self.assertListEqual( + hier_df.index.to_list(), + [ + (0, 0, 0, 10, 0, (0, 1), 1), + (1, 1, 1, 11, 1, (2, 3), 2), + (1, 1, 1, 11, 1, (2, 3), 2), + (2, 2, 1, 12, 1, (1, 3), 3), + ], + ) + self.assertListEqual( + hier_df[("level0_0", "tags")].values.tolist(), + [["tag1"], ["tag2"], ["tag2", "tag1"], ["tag3", "tag4", "tag5"]], + ) def test_has_foreign_columns_in_category_tables(self): """Test confirming working order for DynamicTableRegions in subtables""" @@ -178,31 +258,53 @@ def test_has_foreign_columns_in_category_tables(self): def test_has_foreign_columns_false(self): """Test false if there are no DynamicTableRegionColumns""" - temp_table = DynamicTable(name='t1', description='t1', - colnames=['c1', 'c2'], - columns=[VectorData(name='c1', description='c1', data=np.arange(4)), - VectorData(name='c2', description='c2', data=np.arange(4))]) - temp_aligned_table = AlignedDynamicTable(name='my_aligned_table', - description='my test table', - category_tables=[temp_table], - colnames=['a1', 'a2'], - columns=[VectorData(name='a1', description='c1', data=np.arange(4)), - VectorData(name='a2', description='c2', data=np.arange(4))]) + temp_table = DynamicTable( + name="t1", + description="t1", + colnames=["c1", "c2"], + columns=[ + VectorData(name="c1", description="c1", data=np.arange(4)), + VectorData(name="c2", description="c2", data=np.arange(4)), + ], + ) + temp_aligned_table = AlignedDynamicTable( + name="my_aligned_table", + description="my test table", + category_tables=[temp_table], + colnames=["a1", "a2"], + columns=[ + VectorData(name="a1", description="c1", data=np.arange(4)), + VectorData(name="a2", description="c2", data=np.arange(4)), + ], + ) self.assertFalse(temp_aligned_table.has_foreign_columns()) self.assertFalse(temp_aligned_table.has_foreign_columns(ignore_category_tables=True)) def 
test_has_foreign_column_in_main_table(self): - temp_table = DynamicTable(name='t1', description='t1', - colnames=['c1', 'c2'], - columns=[VectorData(name='c1', description='c1', data=np.arange(4)), - VectorData(name='c2', description='c2', data=np.arange(4))]) - temp_aligned_table = AlignedDynamicTable(name='my_aligned_table', - description='my test table', - category_tables=[temp_table], - colnames=['a1', 'a2'], - columns=[VectorData(name='a1', description='c1', data=np.arange(4)), - DynamicTableRegion(name='a2', description='c2', - data=np.arange(4), table=temp_table)]) + temp_table = DynamicTable( + name="t1", + description="t1", + colnames=["c1", "c2"], + columns=[ + VectorData(name="c1", description="c1", data=np.arange(4)), + VectorData(name="c2", description="c2", data=np.arange(4)), + ], + ) + temp_aligned_table = AlignedDynamicTable( + name="my_aligned_table", + description="my test table", + category_tables=[temp_table], + colnames=["a1", "a2"], + columns=[ + VectorData(name="a1", description="c1", data=np.arange(4)), + DynamicTableRegion( + name="a2", + description="c2", + data=np.arange(4), + table=temp_table, + ), + ], + ) self.assertTrue(temp_aligned_table.has_foreign_columns()) self.assertTrue(temp_aligned_table.has_foreign_columns(ignore_category_tables=True)) @@ -213,45 +315,90 @@ def test_get_foreign_columns(self): # check with subcateogries foreign_cols = self.aligned_table.get_foreign_columns() self.assertEqual(len(foreign_cols), 2) - for i, v in enumerate([('category0', 'child_table_ref1'), ('category1', 'child_table_ref1')]): + for i, v in enumerate( + [ + ("category0", "child_table_ref1"), + ("category1", "child_table_ref1"), + ] + ): self.assertTupleEqual(foreign_cols[i], v) def test_get_foreign_columns_none(self): """Test false if there are no DynamicTableRegionColumns""" - temp_table = DynamicTable(name='t1', description='t1', - colnames=['c1', 'c2'], - columns=[VectorData(name='c1', description='c1', data=np.arange(4)), - VectorData(name='c2', description='c2', data=np.arange(4))]) - temp_aligned_table = AlignedDynamicTable(name='my_aligned_table', - description='my test table', - category_tables=[temp_table], - colnames=['a1', 'a2'], - columns=[VectorData(name='a1', description='c1', data=np.arange(4)), - VectorData(name='a2', description='c2', data=np.arange(4))]) + temp_table = DynamicTable( + name="t1", + description="t1", + colnames=["c1", "c2"], + columns=[ + VectorData(name="c1", description="c1", data=np.arange(4)), + VectorData(name="c2", description="c2", data=np.arange(4)), + ], + ) + temp_aligned_table = AlignedDynamicTable( + name="my_aligned_table", + description="my test table", + category_tables=[temp_table], + colnames=["a1", "a2"], + columns=[ + VectorData(name="a1", description="c1", data=np.arange(4)), + VectorData(name="a2", description="c2", data=np.arange(4)), + ], + ) self.assertListEqual(temp_aligned_table.get_foreign_columns(), []) - self.assertListEqual(temp_aligned_table.get_foreign_columns(ignore_category_tables=True), []) + self.assertListEqual( + temp_aligned_table.get_foreign_columns(ignore_category_tables=True), + [], + ) def test_get_foreign_column_in_main_and_category_table(self): - temp_table0 = DynamicTable(name='t0', description='t1', - colnames=['c1', 'c2'], - columns=[VectorData(name='c1', description='c1', data=np.arange(4)), - VectorData(name='c2', description='c2', data=np.arange(4))]) - temp_table = DynamicTable(name='t1', description='t1', - colnames=['c1', 'c2'], - columns=[VectorData(name='c1', description='c1', 
data=np.arange(4)), - DynamicTableRegion(name='c2', description='c2', - data=np.arange(4), table=temp_table0)]) - temp_aligned_table = AlignedDynamicTable(name='my_aligned_table', - description='my test table', - category_tables=[temp_table], - colnames=['a1', 'a2'], - columns=[VectorData(name='a1', description='c1', data=np.arange(4)), - DynamicTableRegion(name='a2', description='c2', - data=np.arange(4), table=temp_table)]) + temp_table0 = DynamicTable( + name="t0", + description="t1", + colnames=["c1", "c2"], + columns=[ + VectorData(name="c1", description="c1", data=np.arange(4)), + VectorData(name="c2", description="c2", data=np.arange(4)), + ], + ) + temp_table = DynamicTable( + name="t1", + description="t1", + colnames=["c1", "c2"], + columns=[ + VectorData(name="c1", description="c1", data=np.arange(4)), + DynamicTableRegion( + name="c2", + description="c2", + data=np.arange(4), + table=temp_table0, + ), + ], + ) + temp_aligned_table = AlignedDynamicTable( + name="my_aligned_table", + description="my test table", + category_tables=[temp_table], + colnames=["a1", "a2"], + columns=[ + VectorData(name="a1", description="c1", data=np.arange(4)), + DynamicTableRegion( + name="a2", + description="c2", + data=np.arange(4), + table=temp_table, + ), + ], + ) # We should get both the DynamicTableRegion from the main table and the category 't1' - self.assertListEqual(temp_aligned_table.get_foreign_columns(), [(None, 'a2'), ('t1', 'c2')]) + self.assertListEqual( + temp_aligned_table.get_foreign_columns(), + [(None, "a2"), ("t1", "c2")], + ) # We should only get the column from the main table - self.assertListEqual(temp_aligned_table.get_foreign_columns(ignore_category_tables=True), [(None, 'a2')]) + self.assertListEqual( + temp_aligned_table.get_foreign_columns(ignore_category_tables=True), + [(None, "a2")], + ) def test_get_linked_tables(self): # check without subcateogries @@ -260,103 +407,190 @@ def test_get_linked_tables(self): # check with subcateogries linked_tables = self.aligned_table.get_linked_tables() self.assertEqual(len(linked_tables), 2) - self.assertTupleEqual((linked_tables[0].source_table.name, - linked_tables[0].source_column.name, - linked_tables[0].target_table.name), - ('category0', 'child_table_ref1', 'level0_0')) - self.assertTupleEqual((linked_tables[1].source_table.name, - linked_tables[1].source_column.name, - linked_tables[1].target_table.name), - ('category1', 'child_table_ref1', 'level0_1')) + self.assertTupleEqual( + ( + linked_tables[0].source_table.name, + linked_tables[0].source_column.name, + linked_tables[0].target_table.name, + ), + ("category0", "child_table_ref1", "level0_0"), + ) + self.assertTupleEqual( + ( + linked_tables[1].source_table.name, + linked_tables[1].source_column.name, + linked_tables[1].target_table.name, + ), + ("category1", "child_table_ref1", "level0_1"), + ) def test_get_linked_tables_none(self): """Test false if there are no DynamicTableRegionColumns""" - temp_table = DynamicTable(name='t1', description='t1', - colnames=['c1', 'c2'], - columns=[VectorData(name='c1', description='c1', data=np.arange(4)), - VectorData(name='c2', description='c2', data=np.arange(4))]) - temp_aligned_table = AlignedDynamicTable(name='my_aligned_table', - description='my test table', - category_tables=[temp_table], - colnames=['a1', 'a2'], - columns=[VectorData(name='a1', description='c1', data=np.arange(4)), - VectorData(name='a2', description='c2', data=np.arange(4))]) + temp_table = DynamicTable( + name="t1", + description="t1", + colnames=["c1", "c2"], 
+ columns=[ + VectorData(name="c1", description="c1", data=np.arange(4)), + VectorData(name="c2", description="c2", data=np.arange(4)), + ], + ) + temp_aligned_table = AlignedDynamicTable( + name="my_aligned_table", + description="my test table", + category_tables=[temp_table], + colnames=["a1", "a2"], + columns=[ + VectorData(name="a1", description="c1", data=np.arange(4)), + VectorData(name="a2", description="c2", data=np.arange(4)), + ], + ) self.assertListEqual(temp_aligned_table.get_linked_tables(), []) - self.assertListEqual(temp_aligned_table.get_linked_tables(ignore_category_tables=True), []) + self.assertListEqual( + temp_aligned_table.get_linked_tables(ignore_category_tables=True), + [], + ) def test_get_linked_tables_complex_link(self): - temp_table0 = DynamicTable(name='t0', description='t1', - colnames=['c1', 'c2'], - columns=[VectorData(name='c1', description='c1', data=np.arange(4)), - VectorData(name='c2', description='c2', data=np.arange(4))]) - temp_table = DynamicTable(name='t1', description='t1', - colnames=['c1', 'c2'], - columns=[VectorData(name='c1', description='c1', data=np.arange(4)), - DynamicTableRegion(name='c2', description='c2', - data=np.arange(4), table=temp_table0)]) - temp_aligned_table = AlignedDynamicTable(name='my_aligned_table', - description='my test table', - category_tables=[temp_table], - colnames=['a1', 'a2'], - columns=[VectorData(name='a1', description='c1', data=np.arange(4)), - DynamicTableRegion(name='a2', description='c2', - data=np.arange(4), table=temp_table)]) + temp_table0 = DynamicTable( + name="t0", + description="t1", + colnames=["c1", "c2"], + columns=[ + VectorData(name="c1", description="c1", data=np.arange(4)), + VectorData(name="c2", description="c2", data=np.arange(4)), + ], + ) + temp_table = DynamicTable( + name="t1", + description="t1", + colnames=["c1", "c2"], + columns=[ + VectorData(name="c1", description="c1", data=np.arange(4)), + DynamicTableRegion( + name="c2", + description="c2", + data=np.arange(4), + table=temp_table0, + ), + ], + ) + temp_aligned_table = AlignedDynamicTable( + name="my_aligned_table", + description="my test table", + category_tables=[temp_table], + colnames=["a1", "a2"], + columns=[ + VectorData(name="a1", description="c1", data=np.arange(4)), + DynamicTableRegion( + name="a2", + description="c2", + data=np.arange(4), + table=temp_table, + ), + ], + ) # NOTE: in this example templ_aligned_table both points to temp_table and at the # same time contains temp_table as a category. 
This could lead to temp_table # visited multiple times and we want to make sure this doesn't happen # We should get both the DynamicTableRegion from the main table and the category 't1' linked_tables = temp_aligned_table.get_linked_tables() self.assertEqual(len(linked_tables), 2) - for i, v in enumerate([('my_aligned_table', 'a2', 't1'), ('t1', 'c2', 't0')]): - self.assertTupleEqual((linked_tables[i].source_table.name, - linked_tables[i].source_column.name, - linked_tables[i].target_table.name), v) + for i, v in enumerate([("my_aligned_table", "a2", "t1"), ("t1", "c2", "t0")]): + self.assertTupleEqual( + ( + linked_tables[i].source_table.name, + linked_tables[i].source_column.name, + linked_tables[i].target_table.name, + ), + v, + ) # Now, since our main table links to the category table the result should remain the same # even if we ignore the category table linked_tables = temp_aligned_table.get_linked_tables(ignore_category_tables=True) self.assertEqual(len(linked_tables), 2) - for i, v in enumerate([('my_aligned_table', 'a2', 't1'), ('t1', 'c2', 't0')]): - self.assertTupleEqual((linked_tables[i].source_table.name, - linked_tables[i].source_column.name, - linked_tables[i].target_table.name), v) + for i, v in enumerate([("my_aligned_table", "a2", "t1"), ("t1", "c2", "t0")]): + self.assertTupleEqual( + ( + linked_tables[i].source_table.name, + linked_tables[i].source_column.name, + linked_tables[i].target_table.name, + ), + v, + ) def test_get_linked_tables_simple_link(self): - temp_table0 = DynamicTable(name='t0', description='t1', - colnames=['c1', 'c2'], - columns=[VectorData(name='c1', description='c1', data=np.arange(4)), - VectorData(name='c2', description='c2', data=np.arange(4))]) - temp_table = DynamicTable(name='t1', description='t1', - colnames=['c1', 'c2'], - columns=[VectorData(name='c1', description='c1', data=np.arange(4)), - DynamicTableRegion(name='c2', description='c2', - data=np.arange(4), table=temp_table0)]) - temp_aligned_table = AlignedDynamicTable(name='my_aligned_table', - description='my test table', - category_tables=[temp_table], - colnames=['a1', 'a2'], - columns=[VectorData(name='a1', description='c1', data=np.arange(4)), - DynamicTableRegion(name='a2', description='c2', - data=np.arange(4), table=temp_table0)]) + temp_table0 = DynamicTable( + name="t0", + description="t1", + colnames=["c1", "c2"], + columns=[ + VectorData(name="c1", description="c1", data=np.arange(4)), + VectorData(name="c2", description="c2", data=np.arange(4)), + ], + ) + temp_table = DynamicTable( + name="t1", + description="t1", + colnames=["c1", "c2"], + columns=[ + VectorData(name="c1", description="c1", data=np.arange(4)), + DynamicTableRegion( + name="c2", + description="c2", + data=np.arange(4), + table=temp_table0, + ), + ], + ) + temp_aligned_table = AlignedDynamicTable( + name="my_aligned_table", + description="my test table", + category_tables=[temp_table], + colnames=["a1", "a2"], + columns=[ + VectorData(name="a1", description="c1", data=np.arange(4)), + DynamicTableRegion( + name="a2", + description="c2", + data=np.arange(4), + table=temp_table0, + ), + ], + ) # NOTE: in this example temp_aligned_table and temp_table both point to temp_table0 # We should get both the DynamicTableRegion from the main table and the category 't1' linked_tables = temp_aligned_table.get_linked_tables() self.assertEqual(len(linked_tables), 2) - for i, v in enumerate([('my_aligned_table', 'a2', 't0'), ('t1', 'c2', 't0')]): - self.assertTupleEqual((linked_tables[i].source_table.name, - 
linked_tables[i].source_column.name, - linked_tables[i].target_table.name), v) + for i, v in enumerate([("my_aligned_table", "a2", "t0"), ("t1", "c2", "t0")]): + self.assertTupleEqual( + ( + linked_tables[i].source_table.name, + linked_tables[i].source_column.name, + linked_tables[i].target_table.name, + ), + v, + ) # Since no table ever link to our category temp_table we should only get the link from our # main table here, in contrast to what happens in the test_get_linked_tables_complex_link case linked_tables = temp_aligned_table.get_linked_tables() self.assertEqual(len(linked_tables), 2) - for i, v in enumerate([('my_aligned_table', 'a2', 't0'), ]): - self.assertTupleEqual((linked_tables[i].source_table.name, - linked_tables[i].source_column.name, - linked_tables[i].target_table.name), v) + for i, v in enumerate( + [ + ("my_aligned_table", "a2", "t0"), + ] + ): + self.assertTupleEqual( + ( + linked_tables[i].source_table.name, + linked_tables[i].source_column.name, + linked_tables[i].target_table.name, + ), + v, + ) class TestHierarchicalTable(TestCase): - def setUp(self): """ Create basic set of linked tables consisting of @@ -366,40 +600,70 @@ def setUp(self): +--> category0 """ # Level 0 0 table. I.e., first table on level 0 - self.category0 = DynamicTable(name='level0_0', description="level0_0 DynamicTable") + self.category0 = DynamicTable(name="level0_0", description="level0_0 DynamicTable") self.category0.add_row(id=10) self.category0.add_row(id=11) self.category0.add_row(id=12) self.category0.add_row(id=13) - self.category0.add_column(data=['tag1', 'tag2', 'tag2', 'tag1', 'tag3', 'tag4', 'tag5'], - name='tags', - description='custom tags', - index=[1, 2, 4, 7]) - self.category0.add_column(data=np.arange(4), - name='myid', - description='custom ids', - index=False) + self.category0.add_column( + data=["tag1", "tag2", "tag2", "tag1", "tag3", "tag4", "tag5"], + name="tags", + description="custom tags", + index=[1, 2, 4, 7], + ) + self.category0.add_column( + data=np.arange(4), + name="myid", + description="custom ids", + index=False, + ) # Aligned table - self.aligned_table = AlignedDynamicTable(name='aligned_table', - description='parent_table', - columns=[VectorData(name='a1', description='a1', data=np.arange(4)), ], - colnames=['a1', ], - category_tables=[self.category0, ]) + self.aligned_table = AlignedDynamicTable( + name="aligned_table", + description="parent_table", + columns=[ + VectorData(name="a1", description="a1", data=np.arange(4)), + ], + colnames=[ + "a1", + ], + category_tables=[ + self.category0, + ], + ) # Parent table - self.parent_table = DynamicTable(name='parent_table', - description='parent_table', - columns=[VectorData(name='p1', description='p1', data=np.arange(4)), - DynamicTableRegion(name='l1', description='l1', - data=np.arange(4), table=self.aligned_table)]) + self.parent_table = DynamicTable( + name="parent_table", + description="parent_table", + columns=[ + VectorData(name="p1", description="p1", data=np.arange(4)), + DynamicTableRegion( + name="l1", + description="l1", + data=np.arange(4), + table=self.aligned_table, + ), + ], + ) # Super-parent table - dtr_sp = DynamicTableRegion(name='sl1', description='sl1', data=np.arange(4), table=self.parent_table) - vi_dtr_sp = VectorIndex(name='sl1_index', data=[1, 2, 3], target=dtr_sp) - self.super_parent_table = DynamicTable(name='super_parent_table', - description='super_parent_table', - columns=[VectorData(name='sp1', description='sp1', data=np.arange(3)), - dtr_sp, vi_dtr_sp]) + dtr_sp = 
DynamicTableRegion( + name="sl1", + description="sl1", + data=np.arange(4), + table=self.parent_table, + ) + vi_dtr_sp = VectorIndex(name="sl1_index", data=[1, 2, 3], target=dtr_sp) + self.super_parent_table = DynamicTable( + name="super_parent_table", + description="super_parent_table", + columns=[ + VectorData(name="sp1", description="sp1", data=np.arange(3)), + dtr_sp, + vi_dtr_sp, + ], + ) def tearDown(self): del self.category0 @@ -408,38 +672,81 @@ def tearDown(self): def test_to_hierarchical_dataframe_no_dtr_on_top_level(self): # Cover the case where our top dtr is flat (i.e., without a VectorIndex) - dtr_sp = DynamicTableRegion(name='sl1', description='sl1', data=np.arange(4), table=self.parent_table) - spttable = DynamicTable(name='super_parent_table', - description='super_parent_table', - columns=[VectorData(name='sp1', description='sp1', data=np.arange(4)), dtr_sp]) + dtr_sp = DynamicTableRegion( + name="sl1", + description="sl1", + data=np.arange(4), + table=self.parent_table, + ) + spttable = DynamicTable( + name="super_parent_table", + description="super_parent_table", + columns=[ + VectorData(name="sp1", description="sp1", data=np.arange(4)), + dtr_sp, + ], + ) hier_df = to_hierarchical_dataframe(spttable).reset_index() - expected_columns = [('super_parent_table', 'id'), ('super_parent_table', 'sp1'), - ('parent_table', 'id'), ('parent_table', 'p1'), - ('aligned_table', 'id'), - ('aligned_table', ('aligned_table', 'a1')), ('aligned_table', ('level0_0', 'id')), - ('aligned_table', ('level0_0', 'tags')), ('aligned_table', ('level0_0', 'myid'))] + expected_columns = [ + ("super_parent_table", "id"), + ("super_parent_table", "sp1"), + ("parent_table", "id"), + ("parent_table", "p1"), + ("aligned_table", "id"), + ("aligned_table", ("aligned_table", "a1")), + ("aligned_table", ("level0_0", "id")), + ("aligned_table", ("level0_0", "tags")), + ("aligned_table", ("level0_0", "myid")), + ] self.assertListEqual(hier_df.columns.to_list(), expected_columns) def test_to_hierarchical_dataframe_indexed_dtr_on_last_level(self): # Parent table - dtr_p1 = DynamicTableRegion(name='l1', description='l1', data=np.arange(4), table=self.aligned_table) - vi_dtr_p1 = VectorIndex(name='sl1_index', data=[1, 2, 3], target=dtr_p1) - p1 = DynamicTable(name='parent_table', description='parent_table', - columns=[VectorData(name='p1', description='p1', data=np.arange(3)), dtr_p1, vi_dtr_p1]) + dtr_p1 = DynamicTableRegion( + name="l1", + description="l1", + data=np.arange(4), + table=self.aligned_table, + ) + vi_dtr_p1 = VectorIndex(name="sl1_index", data=[1, 2, 3], target=dtr_p1) + p1 = DynamicTable( + name="parent_table", + description="parent_table", + columns=[ + VectorData(name="p1", description="p1", data=np.arange(3)), + dtr_p1, + vi_dtr_p1, + ], + ) # Super-parent table - dtr_sp = DynamicTableRegion(name='sl1', description='sl1', data=np.arange(4), table=p1) - vi_dtr_sp = VectorIndex(name='sl1_index', data=[1, 2, 3], target=dtr_sp) - spt = DynamicTable(name='super_parent_table', description='super_parent_table', - columns=[VectorData(name='sp1', description='sp1', data=np.arange(3)), dtr_sp, vi_dtr_sp]) + dtr_sp = DynamicTableRegion(name="sl1", description="sl1", data=np.arange(4), table=p1) + vi_dtr_sp = VectorIndex(name="sl1_index", data=[1, 2, 3], target=dtr_sp) + spt = DynamicTable( + name="super_parent_table", + description="super_parent_table", + columns=[ + VectorData(name="sp1", description="sp1", data=np.arange(3)), + dtr_sp, + vi_dtr_sp, + ], + ) hier_df = 
to_hierarchical_dataframe(spt).reset_index() - expected_columns = [('super_parent_table', 'id'), ('super_parent_table', 'sp1'), - ('parent_table', 'id'), ('parent_table', 'p1'), - ('aligned_table', 'id'), - ('aligned_table', ('aligned_table', 'a1')), ('aligned_table', ('level0_0', 'id')), - ('aligned_table', ('level0_0', 'tags')), ('aligned_table', ('level0_0', 'myid'))] + expected_columns = [ + ("super_parent_table", "id"), + ("super_parent_table", "sp1"), + ("parent_table", "id"), + ("parent_table", "p1"), + ("aligned_table", "id"), + ("aligned_table", ("aligned_table", "a1")), + ("aligned_table", ("level0_0", "id")), + ("aligned_table", ("level0_0", "tags")), + ("aligned_table", ("level0_0", "myid")), + ] self.assertListEqual(hier_df.columns.to_list(), expected_columns) # make sure we have the right columns - self.assertListEqual(hier_df[('aligned_table', ('level0_0', 'tags'))].to_list(), - [['tag1'], ['tag2'], ['tag2', 'tag1']]) + self.assertListEqual( + hier_df[("aligned_table", ("level0_0", "tags"))].to_list(), + [["tag1"], ["tag2"], ["tag2", "tag1"]], + ) def test_to_hierarchical_dataframe_indexed_data_nparray(self): # Test that we can convert a table that contains a VectorIndex column as regular data, @@ -448,24 +755,51 @@ def test_to_hierarchical_dataframe_indexed_data_nparray(self): # into the MultiIndex of the table. As a numpy array is not hashable this would normally # create an error when creating the MultiIndex # Parent table - dtr_p1 = DynamicTableRegion(name='l1', description='l1', data=np.arange(4), table=self.aligned_table) - vi_dtr_p1 = VectorIndex(name='sl1_index', data=[1, 2, 3], target=dtr_p1) - p1 = DynamicTable(name='parent_table', description='parent_table', - columns=[VectorData(name='p1', description='p1', data=np.arange(3)), dtr_p1, vi_dtr_p1]) + dtr_p1 = DynamicTableRegion( + name="l1", + description="l1", + data=np.arange(4), + table=self.aligned_table, + ) + vi_dtr_p1 = VectorIndex(name="sl1_index", data=[1, 2, 3], target=dtr_p1) + p1 = DynamicTable( + name="parent_table", + description="parent_table", + columns=[ + VectorData(name="p1", description="p1", data=np.arange(3)), + dtr_p1, + vi_dtr_p1, + ], + ) # Super-parent table - dtr_sp = DynamicTableRegion(name='sl1', description='sl1', data=np.arange(3), table=p1) - spt = DynamicTable(name='super_parent_table', description='super_parent_table', - columns=[VectorData(name='sp1', description='sp1', data=np.arange(3)), dtr_sp]) - spt.add_column(name='vic', description='vic', data=np.arange(9), index=[2, 4, 6]) + dtr_sp = DynamicTableRegion(name="sl1", description="sl1", data=np.arange(3), table=p1) + spt = DynamicTable( + name="super_parent_table", + description="super_parent_table", + columns=[ + VectorData(name="sp1", description="sp1", data=np.arange(3)), + dtr_sp, + ], + ) + spt.add_column(name="vic", description="vic", data=np.arange(9), index=[2, 4, 6]) hier_df = to_hierarchical_dataframe(spt).reset_index() - expected_columns = [('super_parent_table', 'id'), ('super_parent_table', 'sp1'), ('super_parent_table', 'vic'), - ('parent_table', 'id'), ('parent_table', 'p1'), - ('aligned_table', 'id'), - ('aligned_table', ('aligned_table', 'a1')), ('aligned_table', ('level0_0', 'id')), - ('aligned_table', ('level0_0', 'tags')), ('aligned_table', ('level0_0', 'myid'))] + expected_columns = [ + ("super_parent_table", "id"), + ("super_parent_table", "sp1"), + ("super_parent_table", "vic"), + ("parent_table", "id"), + ("parent_table", "p1"), + ("aligned_table", "id"), + ("aligned_table", ("aligned_table", "a1")), + 
("aligned_table", ("level0_0", "id")), + ("aligned_table", ("level0_0", "tags")), + ("aligned_table", ("level0_0", "myid")), + ] self.assertListEqual(hier_df.columns.to_list(), expected_columns) # make sure we have the right columns - self.assertListEqual(hier_df[('aligned_table', ('level0_0', 'tags'))].to_list(), - [['tag1'], ['tag2'], ['tag2', 'tag1']]) + self.assertListEqual( + hier_df[("aligned_table", ("level0_0", "tags"))].to_list(), + [["tag1"], ["tag2"], ["tag2", "tag1"]], + ) def test_to_hierarchical_dataframe_indexed_data_list(self): # Test that we can convert a table that contains a VectorIndex column as regular data, @@ -474,80 +808,152 @@ def test_to_hierarchical_dataframe_indexed_data_list(self): # into the MultiIndex of the table. As a list is not hashable this would normally # create an error when creating the MultiIndex # Parent table - dtr_p1 = DynamicTableRegion(name='l1', description='l1', data=np.arange(4), table=self.aligned_table) - vi_dtr_p1 = VectorIndex(name='sl1_index', data=[1, 2, 3], target=dtr_p1) - p1 = DynamicTable(name='parent_table', description='parent_table', - columns=[VectorData(name='p1', description='p1', data=np.arange(3)), dtr_p1, vi_dtr_p1]) + dtr_p1 = DynamicTableRegion( + name="l1", + description="l1", + data=np.arange(4), + table=self.aligned_table, + ) + vi_dtr_p1 = VectorIndex(name="sl1_index", data=[1, 2, 3], target=dtr_p1) + p1 = DynamicTable( + name="parent_table", + description="parent_table", + columns=[ + VectorData(name="p1", description="p1", data=np.arange(3)), + dtr_p1, + vi_dtr_p1, + ], + ) # Super-parent table - dtr_sp = DynamicTableRegion(name='sl1', description='sl1', data=np.arange(3), table=p1) - spt = DynamicTable(name='super_parent_table', description='super_parent_table', - columns=[VectorData(name='sp1', description='sp1', data=np.arange(3)), dtr_sp]) - spt.add_column(name='vic', description='vic', data=list(range(9)), index=list([2, 4, 6])) + dtr_sp = DynamicTableRegion(name="sl1", description="sl1", data=np.arange(3), table=p1) + spt = DynamicTable( + name="super_parent_table", + description="super_parent_table", + columns=[ + VectorData(name="sp1", description="sp1", data=np.arange(3)), + dtr_sp, + ], + ) + spt.add_column( + name="vic", + description="vic", + data=list(range(9)), + index=list([2, 4, 6]), + ) hier_df = to_hierarchical_dataframe(spt).reset_index() - expected_columns = [('super_parent_table', 'id'), ('super_parent_table', 'sp1'), ('super_parent_table', 'vic'), - ('parent_table', 'id'), ('parent_table', 'p1'), - ('aligned_table', 'id'), - ('aligned_table', ('aligned_table', 'a1')), ('aligned_table', ('level0_0', 'id')), - ('aligned_table', ('level0_0', 'tags')), ('aligned_table', ('level0_0', 'myid'))] + expected_columns = [ + ("super_parent_table", "id"), + ("super_parent_table", "sp1"), + ("super_parent_table", "vic"), + ("parent_table", "id"), + ("parent_table", "p1"), + ("aligned_table", "id"), + ("aligned_table", ("aligned_table", "a1")), + ("aligned_table", ("level0_0", "id")), + ("aligned_table", ("level0_0", "tags")), + ("aligned_table", ("level0_0", "myid")), + ] self.assertListEqual(hier_df.columns.to_list(), expected_columns) # make sure we have the right columns - self.assertListEqual(hier_df[('aligned_table', ('level0_0', 'tags'))].to_list(), - [['tag1'], ['tag2'], ['tag2', 'tag1']]) + self.assertListEqual( + hier_df[("aligned_table", ("level0_0", "tags"))].to_list(), + [["tag1"], ["tag2"], ["tag2", "tag1"]], + ) def test_to_hierarchical_dataframe_empty_tables(self): # Setup empty tables with 
the following hierarchy # super_parent_table ---> parent_table ---> child_table - a1 = DynamicTable(name='level0_0', description="level0_0 DynamicTable", - columns=[VectorData(name='l0', description='l0', data=[])]) - p1 = DynamicTable(name='parent_table', description='parent_table', - columns=[DynamicTableRegion(name='l1', description='l1', data=[], table=a1), - VectorData(name='p1c', description='l0', data=[])]) - dtr_sp = DynamicTableRegion(name='sl1', description='sl1', data=np.arange(4), table=p1) - vi_dtr_sp = VectorIndex(name='sl1_index', data=[], target=dtr_sp) - spt = DynamicTable(name='super_parent_table', description='super_parent_table', - columns=[dtr_sp, vi_dtr_sp, VectorData(name='sptc', description='l0', data=[])]) + a1 = DynamicTable( + name="level0_0", + description="level0_0 DynamicTable", + columns=[VectorData(name="l0", description="l0", data=[])], + ) + p1 = DynamicTable( + name="parent_table", + description="parent_table", + columns=[ + DynamicTableRegion(name="l1", description="l1", data=[], table=a1), + VectorData(name="p1c", description="l0", data=[]), + ], + ) + dtr_sp = DynamicTableRegion(name="sl1", description="sl1", data=np.arange(4), table=p1) + vi_dtr_sp = VectorIndex(name="sl1_index", data=[], target=dtr_sp) + spt = DynamicTable( + name="super_parent_table", + description="super_parent_table", + columns=[ + dtr_sp, + vi_dtr_sp, + VectorData(name="sptc", description="l0", data=[]), + ], + ) # Convert to hierarchical dataframe and make sure we get the right columns hier_df = to_hierarchical_dataframe(spt).reset_index() - expected_columns = [('super_parent_table', 'id'), ('super_parent_table', 'sptc'), - ('parent_table', 'id'), ('parent_table', 'p1c'), - ('level0_0', 'id'), ('level0_0', 'l0')] + expected_columns = [ + ("super_parent_table", "id"), + ("super_parent_table", "sptc"), + ("parent_table", "id"), + ("parent_table", "p1c"), + ("level0_0", "id"), + ("level0_0", "l0"), + ] self.assertListEqual(hier_df.columns.to_list(), expected_columns) def test_to_hierarchical_dataframe_multilevel(self): hier_df = to_hierarchical_dataframe(self.super_parent_table).reset_index() - expected_cols = [('super_parent_table', 'id'), ('super_parent_table', 'sp1'), - ('parent_table', 'id'), ('parent_table', 'p1'), - ('aligned_table', 'id'), - ('aligned_table', ('aligned_table', 'a1')), - ('aligned_table', ('level0_0', 'id')), - ('aligned_table', ('level0_0', 'tags')), - ('aligned_table', ('level0_0', 'myid'))] + expected_cols = [ + ("super_parent_table", "id"), + ("super_parent_table", "sp1"), + ("parent_table", "id"), + ("parent_table", "p1"), + ("aligned_table", "id"), + ("aligned_table", ("aligned_table", "a1")), + ("aligned_table", ("level0_0", "id")), + ("aligned_table", ("level0_0", "tags")), + ("aligned_table", ("level0_0", "myid")), + ] # Check that we have all the columns self.assertListEqual(hier_df.columns.to_list(), expected_cols) # Spot-check the data in two columns - self.assertListEqual(hier_df[('aligned_table', ('level0_0', 'tags'))].to_list(), - [['tag1'], ['tag2'], ['tag2', 'tag1']]) - self.assertListEqual(hier_df[('aligned_table', ('aligned_table', 'a1'))].to_list(), list(range(3))) + self.assertListEqual( + hier_df[("aligned_table", ("level0_0", "tags"))].to_list(), + [["tag1"], ["tag2"], ["tag2", "tag1"]], + ) + self.assertListEqual( + hier_df[("aligned_table", ("aligned_table", "a1"))].to_list(), + list(range(3)), + ) def test_to_hierarchical_dataframe(self): hier_df = to_hierarchical_dataframe(self.parent_table) self.assertEqual(len(hier_df), 4) 
self.assertEqual(len(hier_df.columns), 5) self.assertEqual(len(hier_df.index.names), 2) - columns = [('aligned_table', 'id'), - ('aligned_table', ('aligned_table', 'a1')), - ('aligned_table', ('level0_0', 'id')), - ('aligned_table', ('level0_0', 'tags')), - ('aligned_table', ('level0_0', 'myid'))] + columns = [ + ("aligned_table", "id"), + ("aligned_table", ("aligned_table", "a1")), + ("aligned_table", ("level0_0", "id")), + ("aligned_table", ("level0_0", "tags")), + ("aligned_table", ("level0_0", "myid")), + ] for i, c in enumerate(hier_df.columns): self.assertTupleEqual(c, columns[i]) - index_names = [('parent_table', 'id'), ('parent_table', 'p1')] + index_names = [("parent_table", "id"), ("parent_table", "p1")] self.assertListEqual(hier_df.index.names, index_names) self.assertListEqual(hier_df.index.to_list(), [(i, i) for i in range(4)]) - self.assertListEqual(hier_df[('aligned_table', ('aligned_table', 'a1'))].to_list(), list(range(4))) - self.assertListEqual(hier_df[('aligned_table', ('level0_0', 'id'))].to_list(), list(range(10, 14))) - self.assertListEqual(hier_df[('aligned_table', ('level0_0', 'myid'))].to_list(), list(range(4))) - tags = [['tag1'], ['tag2'], ['tag2', 'tag1'], ['tag3', 'tag4', 'tag5']] - for i, v in enumerate(hier_df[('aligned_table', ('level0_0', 'tags'))].to_list()): + self.assertListEqual( + hier_df[("aligned_table", ("aligned_table", "a1"))].to_list(), + list(range(4)), + ) + self.assertListEqual( + hier_df[("aligned_table", ("level0_0", "id"))].to_list(), + list(range(10, 14)), + ) + self.assertListEqual( + hier_df[("aligned_table", ("level0_0", "myid"))].to_list(), + list(range(4)), + ) + tags = [["tag1"], ["tag2"], ["tag2", "tag1"], ["tag3", "tag4", "tag5"]] + for i, v in enumerate(hier_df[("aligned_table", ("level0_0", "tags"))].to_list()): self.assertListEqual(v, tags[i]) def test_to_hierarchical_dataframe_flat_table(self): @@ -560,32 +966,39 @@ def test_drop_id_columns(self): hier_df = to_hierarchical_dataframe(self.parent_table) cols = hier_df.columns.to_list() mod_df = drop_id_columns(hier_df, inplace=False) - expected_cols = [('aligned_table', ('aligned_table', 'a1')), - ('aligned_table', ('level0_0', 'tags')), - ('aligned_table', ('level0_0', 'myid'))] + expected_cols = [ + ("aligned_table", ("aligned_table", "a1")), + ("aligned_table", ("level0_0", "tags")), + ("aligned_table", ("level0_0", "myid")), + ] self.assertListEqual(hier_df.columns.to_list(), cols) # Test that no columns are dropped with inplace=False - self.assertListEqual(mod_df.columns.to_list(), expected_cols) # Assert that we got back a modified dataframe + self.assertListEqual(mod_df.columns.to_list(), expected_cols) # Assert that we got back a modified dataframe drop_id_columns(hier_df, inplace=True) - self.assertListEqual(hier_df.columns.to_list(), - expected_cols) + self.assertListEqual(hier_df.columns.to_list(), expected_cols) flat_df = to_hierarchical_dataframe(self.parent_table).reset_index(inplace=False) drop_id_columns(flat_df, inplace=True) - self.assertListEqual(flat_df.columns.to_list(), - [('parent_table', 'p1'), - ('aligned_table', ('aligned_table', 'a1')), - ('aligned_table', ('level0_0', 'tags')), - ('aligned_table', ('level0_0', 'myid'))]) + self.assertListEqual( + flat_df.columns.to_list(), + [ + ("parent_table", "p1"), + ("aligned_table", ("aligned_table", "a1")), + ("aligned_table", ("level0_0", "tags")), + ("aligned_table", ("level0_0", "myid")), + ], + ) def test_flatten_column_index(self): hier_df = to_hierarchical_dataframe(self.parent_table).reset_index() cols = 
hier_df.columns.to_list() - expexted_cols = [('parent_table', 'id'), - ('parent_table', 'p1'), - ('aligned_table', 'id'), - ('aligned_table', 'aligned_table', 'a1'), - ('aligned_table', 'level0_0', 'id'), - ('aligned_table', 'level0_0', 'tags'), - ('aligned_table', 'level0_0', 'myid')] + expexted_cols = [ + ("parent_table", "id"), + ("parent_table", "p1"), + ("aligned_table", "id"), + ("aligned_table", "aligned_table", "a1"), + ("aligned_table", "level0_0", "id"), + ("aligned_table", "level0_0", "tags"), + ("aligned_table", "level0_0", "myid"), + ] df = flatten_column_index(hier_df, inplace=False) # Test that our columns have not changed with inplace=False self.assertListEqual(hier_df.columns.to_list(), cols) @@ -594,19 +1007,26 @@ def test_flatten_column_index(self): self.assertListEqual(hier_df.columns.to_list(), expexted_cols) # Test that we can apply flatten_column_index again on our already modified dataframe to reduce the levels flatten_column_index(hier_df, inplace=True, max_levels=2) - expexted_cols = [('parent_table', 'id'), ('parent_table', 'p1'), ('aligned_table', 'id'), - ('aligned_table', 'a1'), ('level0_0', 'id'), ('level0_0', 'tags'), ('level0_0', 'myid')] + expexted_cols = [ + ("parent_table", "id"), + ("parent_table", "p1"), + ("aligned_table", "id"), + ("aligned_table", "a1"), + ("level0_0", "id"), + ("level0_0", "tags"), + ("level0_0", "myid"), + ] self.assertListEqual(hier_df.columns.to_list(), expexted_cols) # Test that we can directly reduce the max_levels to just 1 hier_df = to_hierarchical_dataframe(self.parent_table).reset_index() flatten_column_index(hier_df, inplace=True, max_levels=1) - expexted_cols = ['id', 'p1', 'id', 'a1', 'id', 'tags', 'myid'] + expexted_cols = ["id", "p1", "id", "a1", "id", "tags", "myid"] self.assertListEqual(hier_df.columns.to_list(), expexted_cols) def test_flatten_column_index_already_flat_index(self): hier_df = to_hierarchical_dataframe(self.parent_table).reset_index() flatten_column_index(hier_df, inplace=True, max_levels=1) - expexted_cols = ['id', 'p1', 'id', 'a1', 'id', 'tags', 'myid'] + expexted_cols = ["id", "p1", "id", "a1", "id", "tags", "myid"] self.assertListEqual(hier_df.columns.to_list(), expexted_cols) # Now try to flatten the already flat columns again to make sure nothing changes flatten_column_index(hier_df, inplace=True, max_levels=1) @@ -614,9 +1034,9 @@ def test_flatten_column_index_already_flat_index(self): def test_flatten_column_index_bad_maxlevels(self): hier_df = to_hierarchical_dataframe(self.parent_table) - with self.assertRaisesWith(ValueError, 'max_levels must be greater than 0'): + with self.assertRaisesWith(ValueError, "max_levels must be greater than 0"): flatten_column_index(dataframe=hier_df, inplace=True, max_levels=-1) - with self.assertRaisesWith(ValueError, 'max_levels must be greater than 0'): + with self.assertRaisesWith(ValueError, "max_levels must be greater than 0"): flatten_column_index(dataframe=hier_df, inplace=True, max_levels=0) @@ -629,6 +1049,7 @@ class TestLinkedDynamicTables(TestCase): we test with container class. The only time I/O becomes relevant is on read in case that, e.g., a h5py.Dataset may behave differently than a numpy array. 
""" + def setUp(self): """ Create basic set of linked tables consisting of @@ -638,12 +1059,14 @@ def setUp(self): ------> table_level_0_1 """ - self.table_level0_0 = DynamicTable(name='level0_0', description="level0_0 DynamicTable") - self.table_level0_1 = DynamicTable(name='level0_1', description="level0_1 DynamicTable") - self.table_level1 = DynamicTableMultiDTR(name='level1', - child_table1=self.table_level0_0, - child_table2=self.table_level0_1) - self.table_level2 = DynamicTableSingleDTR(name='level2', child_table1=self.table_level1) + self.table_level0_0 = DynamicTable(name="level0_0", description="level0_0 DynamicTable") + self.table_level0_1 = DynamicTable(name="level0_1", description="level0_1 DynamicTable") + self.table_level1 = DynamicTableMultiDTR( + name="level1", + child_table1=self.table_level0_0, + child_table2=self.table_level0_1, + ) + self.table_level2 = DynamicTableSingleDTR(name="level2", child_table1=self.table_level1) def tearDown(self): del self.table_level0_0 @@ -658,80 +1081,106 @@ def popolate_tables(self): self.table_level0_0.add_row(id=11) self.table_level0_0.add_row(id=12) self.table_level0_0.add_row(id=13) - self.table_level0_0.add_column(data=['tag1', 'tag2', 'tag2', 'tag1', 'tag3', 'tag4', 'tag5'], - name='tags', - description='custom tags', - index=[1, 2, 4, 7]) - self.table_level0_0.add_column(data=np.arange(4), - name='myid', - description='custom ids', - index=False) + self.table_level0_0.add_column( + data=["tag1", "tag2", "tag2", "tag1", "tag3", "tag4", "tag5"], + name="tags", + description="custom tags", + index=[1, 2, 4, 7], + ) + self.table_level0_0.add_column( + data=np.arange(4), + name="myid", + description="custom ids", + index=False, + ) # Level 0 1 table. I.e., second table on level 0 self.table_level0_1.add_row(id=14) self.table_level0_1.add_row(id=15) self.table_level0_1.add_row(id=16) self.table_level0_1.add_row(id=17) - self.table_level0_1.add_column(data=['tag1', 'tag1', 'tag2', 'tag2', 'tag3', 'tag3', 'tag4'], - name='tags', - description='custom tags', - index=[2, 4, 6, 7]) - self.table_level0_1.add_column(data=np.arange(4), - name='myid', - description='custom ids', - index=False) + self.table_level0_1.add_column( + data=["tag1", "tag1", "tag2", "tag2", "tag3", "tag3", "tag4"], + name="tags", + description="custom tags", + index=[2, 4, 6, 7], + ) + self.table_level0_1.add_column( + data=np.arange(4), + name="myid", + description="custom ids", + index=False, + ) # Level 1 table self.table_level1.add_row(id=0, child_table_ref1=[0, 1], child_table_ref2=[0]) self.table_level1.add_row(id=1, child_table_ref1=[2], child_table_ref2=[1, 2]) self.table_level1.add_row(id=2, child_table_ref1=[3], child_table_ref2=[3]) - self.table_level1.add_column(data=['tag1', 'tag2', 'tag2'], - name='tag', - description='custom tag', - index=False) - self.table_level1.add_column(data=['tag1', 'tag2', 'tag2', 'tag3', 'tag3', 'tag4', 'tag5'], - name='tags', - description='custom tags', - index=[2, 4, 7]) + self.table_level1.add_column( + data=["tag1", "tag2", "tag2"], + name="tag", + description="custom tag", + index=False, + ) + self.table_level1.add_column( + data=["tag1", "tag2", "tag2", "tag3", "tag3", "tag4", "tag5"], + name="tags", + description="custom tags", + index=[2, 4, 7], + ) # Level 2 data - self.table_level2.add_row(id=0, child_table_ref1=[0, ]) + self.table_level2.add_row( + id=0, + child_table_ref1=[ + 0, + ], + ) self.table_level2.add_row(id=1, child_table_ref1=[1, 2]) - self.table_level2.add_column(data=[10, 12], - name='filter', - 
description='filter value', - index=False) + self.table_level2.add_column( + data=[10, 12], + name="filter", + description="filter value", + index=False, + ) def test_populate_table_hierarchy(self): """Test that just checks that populating the tables with data works correctly""" self.popolate_tables() # Check level0 0 data self.assertListEqual(self.table_level0_0.id[:], np.arange(10, 14, 1).tolist()) - self.assertListEqual(self.table_level0_0['tags'][:], - [['tag1'], ['tag2'], ['tag2', 'tag1'], ['tag3', 'tag4', 'tag5']]) - self.assertListEqual(self.table_level0_0['myid'][:].tolist(), np.arange(0, 4, 1).tolist()) + self.assertListEqual( + self.table_level0_0["tags"][:], + [["tag1"], ["tag2"], ["tag2", "tag1"], ["tag3", "tag4", "tag5"]], + ) + self.assertListEqual(self.table_level0_0["myid"][:].tolist(), np.arange(0, 4, 1).tolist()) # Check level0 1 data self.assertListEqual(self.table_level0_1.id[:], np.arange(14, 18, 1).tolist()) - self.assertListEqual(self.table_level0_1['tags'][:], - [['tag1', 'tag1'], ['tag2', 'tag2'], ['tag3', 'tag3'], ['tag4']]) - self.assertListEqual(self.table_level0_1['myid'][:].tolist(), np.arange(0, 4, 1).tolist()) + self.assertListEqual( + self.table_level0_1["tags"][:], + [["tag1", "tag1"], ["tag2", "tag2"], ["tag3", "tag3"], ["tag4"]], + ) + self.assertListEqual(self.table_level0_1["myid"][:].tolist(), np.arange(0, 4, 1).tolist()) # Check level1 data self.assertListEqual(self.table_level1.id[:], np.arange(0, 3, 1).tolist()) - self.assertListEqual(self.table_level1['tag'][:], ['tag1', 'tag2', 'tag2']) - self.assertTrue(self.table_level1['child_table_ref1'].target.table is self.table_level0_0) - self.assertTrue(self.table_level1['child_table_ref2'].target.table is self.table_level0_1) - self.assertEqual(len(self.table_level1['child_table_ref1'].target.table), 4) - self.assertEqual(len(self.table_level1['child_table_ref2'].target.table), 4) + self.assertListEqual(self.table_level1["tag"][:], ["tag1", "tag2", "tag2"]) + self.assertTrue(self.table_level1["child_table_ref1"].target.table is self.table_level0_0) + self.assertTrue(self.table_level1["child_table_ref2"].target.table is self.table_level0_1) + self.assertEqual(len(self.table_level1["child_table_ref1"].target.table), 4) + self.assertEqual(len(self.table_level1["child_table_ref2"].target.table), 4) # Check level2 data self.assertListEqual(self.table_level2.id[:], np.arange(0, 2, 1).tolist()) - self.assertListEqual(self.table_level2['filter'][:], [10, 12]) - self.assertTrue(self.table_level2['child_table_ref1'].target.table is self.table_level1) - self.assertEqual(len(self.table_level2['child_table_ref1'].target.table), 3) + self.assertListEqual(self.table_level2["filter"][:], [10, 12]) + self.assertTrue(self.table_level2["child_table_ref1"].target.table is self.table_level1) + self.assertEqual(len(self.table_level2["child_table_ref1"].target.table), 3) def test_get_foreign_columns(self): """Test DynamicTable.get_foreign_columns""" self.popolate_tables() self.assertListEqual(self.table_level0_0.get_foreign_columns(), []) self.assertListEqual(self.table_level0_1.get_foreign_columns(), []) - self.assertListEqual(self.table_level1.get_foreign_columns(), ['child_table_ref1', 'child_table_ref2']) - self.assertListEqual(self.table_level2.get_foreign_columns(), ['child_table_ref1']) + self.assertListEqual( + self.table_level1.get_foreign_columns(), + ["child_table_ref1", "child_table_ref2"], + ) + self.assertListEqual(self.table_level2.get_foreign_columns(), ["child_table_ref1"]) def test_has_foreign_columns(self): 
"""Test DynamicTable.get_foreign_columns""" @@ -752,20 +1201,20 @@ def test_get_linked_tables(self): temp = self.table_level1.get_linked_tables() self.assertEqual(len(temp), 2) self.assertEqual(temp[0].source_table.name, self.table_level1.name) - self.assertEqual(temp[0].source_column.name, 'child_table_ref1') + self.assertEqual(temp[0].source_column.name, "child_table_ref1") self.assertEqual(temp[0].target_table.name, self.table_level0_0.name) self.assertEqual(temp[1].source_table.name, self.table_level1.name) - self.assertEqual(temp[1].source_column.name, 'child_table_ref2') + self.assertEqual(temp[1].source_column.name, "child_table_ref2") self.assertEqual(temp[1].target_table.name, self.table_level0_1.name) # check level2 temp = self.table_level2.get_linked_tables() self.assertEqual(len(temp), 3) self.assertEqual(temp[0].source_table.name, self.table_level2.name) - self.assertEqual(temp[0].source_column.name, 'child_table_ref1') + self.assertEqual(temp[0].source_column.name, "child_table_ref1") self.assertEqual(temp[0].target_table.name, self.table_level1.name) self.assertEqual(temp[1].source_table.name, self.table_level1.name) - self.assertEqual(temp[1].source_column.name, 'child_table_ref1') + self.assertEqual(temp[1].source_column.name, "child_table_ref1") self.assertEqual(temp[1].target_table.name, self.table_level0_0.name) self.assertEqual(temp[2].source_table.name, self.table_level1.name) - self.assertEqual(temp[2].source_column.name, 'child_table_ref2') + self.assertEqual(temp[2].source_column.name, "child_table_ref2") self.assertEqual(temp[2].target_table.name, self.table_level0_1.name) diff --git a/tests/unit/common/test_multi.py b/tests/unit/common/test_multi.py index b466bd733..2fc1bd44c 100644 --- a/tests/unit/common/test_multi.py +++ b/tests/unit/common/test_multi.py @@ -1,16 +1,15 @@ from hdmf.common import SimpleMultiContainer from hdmf.container import Container, Data -from hdmf.testing import TestCase, H5RoundTripMixin +from hdmf.testing import H5RoundTripMixin, TestCase class SimpleMultiContainerRoundTrip(H5RoundTripMixin, TestCase): - def setUpContainer(self): containers = [ - Container('container1'), - Container('container2'), - Data('data1', [0, 1, 2, 3, 4]), - Data('data2', [0.0, 1.0, 2.0, 3.0, 4.0]), + Container("container1"), + Container("container2"), + Data("data1", [0, 1, 2, 3, 4]), + Data("data2", [0.0, 1.0, 2.0, 3.0, 4.0]), ] - multi_container = SimpleMultiContainer(name='multi', containers=containers) + multi_container = SimpleMultiContainer(name="multi", containers=containers) return multi_container diff --git a/tests/unit/common/test_resources.py b/tests/unit/common/test_resources.py index a278ad1a8..9900489e7 100644 --- a/tests/unit/common/test_resources.py +++ b/tests/unit/common/test_resources.py @@ -1,46 +1,39 @@ +import numpy as np import pandas as pd + +from hdmf import Container, Data, ExternalResourcesManager from hdmf.common import DynamicTable from hdmf.common.resources import ExternalResources, Key -from hdmf import Data, Container, ExternalResourcesManager -from hdmf.testing import TestCase, H5RoundTripMixin, remove_test_file -import numpy as np +from hdmf.spec import AttributeSpec, DatasetSpec, GroupSpec +from hdmf.testing import H5RoundTripMixin, TestCase, remove_test_file from tests.unit.build_tests.test_io_map import Bar -from tests.unit.helpers.utils import create_test_type_map, CORE_NAMESPACE -from hdmf.spec import GroupSpec, AttributeSpec, DatasetSpec + +from ..helpers.utils import CORE_NAMESPACE, create_test_type_map class 
ExternalResourcesManagerContainer(Container, ExternalResourcesManager): def __init__(self, **kwargs): - kwargs['name'] = 'ExternalResourcesManagerContainer' + kwargs["name"] = "ExternalResourcesManagerContainer" super().__init__(**kwargs) class TestExternalResources(H5RoundTripMixin, TestCase): - def setUpContainer(self): er = ExternalResources() - file = ExternalResourcesManagerContainer(name='file') - file2 = ExternalResourcesManagerContainer(name='file2') - er.add_ref(file=file, - container=file, - key='special', - entity_id="id11", - entity_uri='url11') - er.add_ref(file=file2, - container=file2, - key='key2', - entity_id="id12", - entity_uri='url12') + file = ExternalResourcesManagerContainer(name="file") + file2 = ExternalResourcesManagerContainer(name="file2") + er.add_ref(file=file, container=file, key="special", entity_id="id11", entity_uri="url11") + er.add_ref(file=file2, container=file2, key="key2", entity_id="id12", entity_uri="url12") return er def remove_er_files(self): - remove_test_file('./entities.tsv') - remove_test_file('./objects.tsv') - remove_test_file('./object_keys.tsv') - remove_test_file('./keys.tsv') - remove_test_file('./files.tsv') - remove_test_file('./er.tsv') + remove_test_file("./entities.tsv") + remove_test_file("./objects.tsv") + remove_test_file("./object_keys.tsv") + remove_test_file("./keys.tsv") + remove_test_file("./files.tsv") + remove_test_file("./er.tsv") def test_to_dataframe(self): # Setup complex external resources with keys reused across objects and @@ -48,287 +41,276 @@ def test_to_dataframe(self): er = ExternalResources() # Add a species dataset with 2 keys data1 = Data( - name='data_name', + name="data_name", data=np.array( - [('Mus musculus', 9, 81.0), ('Homo sapiens', 3, 27.0)], - dtype=[('species', 'U14'), ('age', 'i4'), ('weight', 'f4')] - ) + [("Mus musculus", 9, 81.0), ("Homo sapiens", 3, 27.0)], + dtype=[("species", "U14"), ("age", "i4"), ("weight", "f4")], + ), ) - file = ExternalResourcesManagerContainer(name='file') + file = ExternalResourcesManagerContainer(name="file") - ck1, e1 = er.add_ref(file=file, - container=data1, - field='species', - key='Mus musculus', - entity_id='NCBI:txid10090', - entity_uri='https://www.ncbi.nlm.nih.gov/Taxonomy/Browser/wwwtax.cgi?id=10090') - k2, e2 = er.add_ref(file=file, - container=data1, - field='species', - key='Homo sapiens', - entity_id='NCBI:txid9606', - entity_uri='https://www.ncbi.nlm.nih.gov/Taxonomy/Browser/wwwtax.cgi?id=9606') + ck1, e1 = er.add_ref( + file=file, + container=data1, + field="species", + key="Mus musculus", + entity_id="NCBI:txid10090", + entity_uri="https://www.ncbi.nlm.nih.gov/Taxonomy/Browser/wwwtax.cgi?id=10090", + ) + k2, e2 = er.add_ref( + file=file, + container=data1, + field="species", + key="Homo sapiens", + entity_id="NCBI:txid9606", + entity_uri="https://www.ncbi.nlm.nih.gov/Taxonomy/Browser/wwwtax.cgi?id=9606", + ) # Convert to dataframe and compare against the expected result result_df = er.to_dataframe() - expected_df_data = \ - {'file_object_id': {0: file.object_id, 1: file.object_id}, - 'objects_idx': {0: 0, 1: 0}, - 'object_id': {0: data1.object_id, 1: data1.object_id}, - 'files_idx': {0: 0, 1: 0}, - 'object_type': {0: 'Data', 1: 'Data'}, - 'relative_path': {0: '', 1: ''}, - 'field': {0: 'species', 1: 'species'}, - 'keys_idx': {0: 0, 1: 1}, - 'key': {0: 'Mus musculus', 1: 'Homo sapiens'}, - 'entities_idx': {0: 0, 1: 1}, - 'entity_id': {0: 'NCBI:txid10090', 1: 'NCBI:txid9606'}, - 'entity_uri': {0: 
'https://www.ncbi.nlm.nih.gov/Taxonomy/Browser/wwwtax.cgi?id=10090', - 1: 'https://www.ncbi.nlm.nih.gov/Taxonomy/Browser/wwwtax.cgi?id=9606'}} + expected_df_data = { + "file_object_id": {0: file.object_id, 1: file.object_id}, + "objects_idx": {0: 0, 1: 0}, + "object_id": {0: data1.object_id, 1: data1.object_id}, + "files_idx": {0: 0, 1: 0}, + "object_type": {0: "Data", 1: "Data"}, + "relative_path": {0: "", 1: ""}, + "field": {0: "species", 1: "species"}, + "keys_idx": {0: 0, 1: 1}, + "key": {0: "Mus musculus", 1: "Homo sapiens"}, + "entities_idx": {0: 0, 1: 1}, + "entity_id": {0: "NCBI:txid10090", 1: "NCBI:txid9606"}, + "entity_uri": { + 0: "https://www.ncbi.nlm.nih.gov/Taxonomy/Browser/wwwtax.cgi?id=10090", + 1: "https://www.ncbi.nlm.nih.gov/Taxonomy/Browser/wwwtax.cgi?id=9606", + }, + } expected_df = pd.DataFrame.from_dict(expected_df_data) - expected_df = expected_df.astype({'keys_idx': 'uint32', - 'objects_idx': 'uint32', - 'files_idx': 'uint32', - 'entities_idx': 'uint32'}) + expected_df = expected_df.astype( + {"keys_idx": "uint32", "objects_idx": "uint32", "files_idx": "uint32", "entities_idx": "uint32"} + ) pd.testing.assert_frame_equal(result_df, expected_df) def test_assert_external_resources_equal(self): - file = ExternalResourcesManagerContainer(name='file') - ref_container_1 = Container(name='Container_1') + file = ExternalResourcesManagerContainer(name="file") + ref_container_1 = Container(name="Container_1") er_left = ExternalResources() - er_left.add_ref(file=file, - container=ref_container_1, - key='key1', - entity_id="id11", - entity_uri='url11') + er_left.add_ref(file=file, container=ref_container_1, key="key1", entity_id="id11", entity_uri="url11") er_right = ExternalResources() - er_right.add_ref(file=file, - container=ref_container_1, - key='key1', - entity_id="id11", - entity_uri='url11') + er_right.add_ref(file=file, container=ref_container_1, key="key1", entity_id="id11", entity_uri="url11") - self.assertTrue(ExternalResources.assert_external_resources_equal(er_left, - er_right)) + self.assertTrue(ExternalResources.assert_external_resources_equal(er_left, er_right)) def test_invalid_keys_assert_external_resources_equal(self): er_left = ExternalResources() - er_left.add_ref(file=ExternalResourcesManagerContainer(name='file'), - container=Container(name='Container'), - key='key1', - entity_id="id11", - entity_uri='url11') + er_left.add_ref( + file=ExternalResourcesManagerContainer(name="file"), + container=Container(name="Container"), + key="key1", + entity_id="id11", + entity_uri="url11", + ) er_right = ExternalResources() - er_right.add_ref(file=ExternalResourcesManagerContainer(name='file'), - container=Container(name='Container'), - key='invalid', - entity_id="id11", - entity_uri='url11') + er_right.add_ref( + file=ExternalResourcesManagerContainer(name="file"), + container=Container(name="Container"), + key="invalid", + entity_id="id11", + entity_uri="url11", + ) with self.assertRaises(AssertionError): - ExternalResources.assert_external_resources_equal(er_left, - er_right) + ExternalResources.assert_external_resources_equal(er_left, er_right) def test_invalid_objects_assert_external_resources_equal(self): er_left = ExternalResources() - er_left.add_ref(file=ExternalResourcesManagerContainer(name='file'), - container=Container(name='Container'), - key='key1', - entity_id="id11", - entity_uri='url11') + er_left.add_ref( + file=ExternalResourcesManagerContainer(name="file"), + container=Container(name="Container"), + key="key1", + entity_id="id11", + 
entity_uri="url11", + ) er_right = ExternalResources() - er_right.add_ref(file=ExternalResourcesManagerContainer(name='file'), - container=Container(name='Container'), - key='key1', - entity_id="id11", - entity_uri='url11') + er_right.add_ref( + file=ExternalResourcesManagerContainer(name="file"), + container=Container(name="Container"), + key="key1", + entity_id="id11", + entity_uri="url11", + ) with self.assertRaises(AssertionError): - ExternalResources.assert_external_resources_equal(er_left, - er_right) + ExternalResources.assert_external_resources_equal(er_left, er_right) def test_invalid_entity_assert_external_resources_equal(self): er_left = ExternalResources() - er_left.add_ref(file=ExternalResourcesManagerContainer(name='file'), - container=Container(name='Container'), - key='key1', - entity_id="invalid", - entity_uri='invalid') + er_left.add_ref( + file=ExternalResourcesManagerContainer(name="file"), + container=Container(name="Container"), + key="key1", + entity_id="invalid", + entity_uri="invalid", + ) er_right = ExternalResources() - er_right.add_ref(file=ExternalResourcesManagerContainer(name='file'), - container=Container(name='Container'), - key='key1', - entity_id="id11", - entity_uri='url11') + er_right.add_ref( + file=ExternalResourcesManagerContainer(name="file"), + container=Container(name="Container"), + key="key1", + entity_id="id11", + entity_uri="url11", + ) with self.assertRaises(AssertionError): - ExternalResources.assert_external_resources_equal(er_left, - er_right) + ExternalResources.assert_external_resources_equal(er_left, er_right) def test_invalid_object_keys_assert_external_resources_equal(self): er_left = ExternalResources() - er_left.add_ref(file=ExternalResourcesManagerContainer(name='file'), - container=Container(name='Container'), - key='invalid', - entity_id="id11", - entity_uri='url11') + er_left.add_ref( + file=ExternalResourcesManagerContainer(name="file"), + container=Container(name="Container"), + key="invalid", + entity_id="id11", + entity_uri="url11", + ) er_right = ExternalResources() - er_right._add_key('key') - er_right.add_ref(file=ExternalResourcesManagerContainer(name='file'), - container=Container(name='Container'), - key='key1', - entity_id="id11", - entity_uri='url11') + er_right._add_key("key") + er_right.add_ref( + file=ExternalResourcesManagerContainer(name="file"), + container=Container(name="Container"), + key="key1", + entity_id="id11", + entity_uri="url11", + ) with self.assertRaises(AssertionError): - ExternalResources.assert_external_resources_equal(er_left, - er_right) + ExternalResources.assert_external_resources_equal(er_left, er_right) def test_add_ref_search_for_file(self): em = ExternalResourcesManagerContainer() er = ExternalResources() - er.add_ref(container=em, key='key1', - entity_id='entity_id1', entity_uri='entity1') - self.assertEqual(er.keys.data, [('key1',)]) - self.assertEqual(er.entities.data, [(0, 'entity_id1', 'entity1')]) - self.assertEqual(er.objects.data, [(0, em.object_id, 'ExternalResourcesManagerContainer', '', '')]) + er.add_ref(container=em, key="key1", entity_id="entity_id1", entity_uri="entity1") + self.assertEqual(er.keys.data, [("key1",)]) + self.assertEqual(er.entities.data, [(0, "entity_id1", "entity1")]) + self.assertEqual(er.objects.data, [(0, em.object_id, "ExternalResourcesManagerContainer", "", "")]) def test_add_ref_search_for_file_parent(self): em = ExternalResourcesManagerContainer() - child = Container(name='child') + child = Container(name="child") child.parent = em er = 
ExternalResources() - er.add_ref(container=child, key='key1', - entity_id='entity_id1', entity_uri='entity1') - self.assertEqual(er.keys.data, [('key1',)]) - self.assertEqual(er.entities.data, [(0, 'entity_id1', 'entity1')]) - self.assertEqual(er.objects.data, [(0, child.object_id, 'Container', '', '')]) + er.add_ref(container=child, key="key1", entity_id="entity_id1", entity_uri="entity1") + self.assertEqual(er.keys.data, [("key1",)]) + self.assertEqual(er.entities.data, [(0, "entity_id1", "entity1")]) + self.assertEqual(er.objects.data, [(0, child.object_id, "Container", "", "")]) def test_add_ref_search_for_file_nested_parent(self): em = ExternalResourcesManagerContainer() - nested_child = Container(name='nested_child') - child = Container(name='child') + nested_child = Container(name="nested_child") + child = Container(name="child") nested_child.parent = child child.parent = em er = ExternalResources() - er.add_ref(container=nested_child, key='key1', - entity_id='entity_id1', entity_uri='entity1') - self.assertEqual(er.keys.data, [('key1',)]) - self.assertEqual(er.entities.data, [(0, 'entity_id1', 'entity1')]) - self.assertEqual(er.objects.data, [(0, nested_child.object_id, 'Container', '', '')]) + er.add_ref(container=nested_child, key="key1", entity_id="entity_id1", entity_uri="entity1") + self.assertEqual(er.keys.data, [("key1",)]) + self.assertEqual(er.entities.data, [(0, "entity_id1", "entity1")]) + self.assertEqual(er.objects.data, [(0, nested_child.object_id, "Container", "", "")]) def test_add_ref_search_for_file_error(self): - container = Container(name='container') + container = Container(name="container") er = ExternalResources() with self.assertRaises(ValueError): - er.add_ref(container=container, - key='key1', - entity_id='entity_id1', - entity_uri='entity1') + er.add_ref(container=container, key="key1", entity_id="entity_id1", entity_uri="entity1") def test_add_ref(self): er = ExternalResources() - data = Data(name="species", data=['Homo sapiens', 'Mus musculus']) - er.add_ref(file=ExternalResourcesManagerContainer(name='file'), - container=data, - key='key1', - entity_id='entity_id1', - entity_uri='entity1') - self.assertEqual(er.keys.data, [('key1',)]) - self.assertEqual(er.entities.data, [(0, 'entity_id1', 'entity1')]) - self.assertEqual(er.objects.data, [(0, data.object_id, 'Data', '', '')]) + data = Data(name="species", data=["Homo sapiens", "Mus musculus"]) + er.add_ref( + file=ExternalResourcesManagerContainer(name="file"), + container=data, + key="key1", + entity_id="entity_id1", + entity_uri="entity1", + ) + self.assertEqual(er.keys.data, [("key1",)]) + self.assertEqual(er.entities.data, [(0, "entity_id1", "entity1")]) + self.assertEqual(er.objects.data, [(0, data.object_id, "Data", "", "")]) def test_get_object_type(self): er = ExternalResources() - file = ExternalResourcesManagerContainer(name='file') - data = Data(name="species", data=['Homo sapiens', 'Mus musculus']) - er.add_ref(file=file, - container=data, - key='key1', - entity_id='entity_id1', - entity_uri='entity1') - - df = er.get_object_type(object_type='Data') - - expected_df_data = \ - {'file_object_id': {0: file.object_id}, - 'objects_idx': {0: 0}, - 'object_id': {0: data.object_id}, - 'files_idx': {0: 0}, - 'object_type': {0: 'Data'}, - 'relative_path': {0: ''}, - 'field': {0: ''}, - 'keys_idx': {0: 0}, - 'key': {0: 'key1'}, - 'entities_idx': {0: 0}, - 'entity_id': {0: 'entity_id1'}, - 'entity_uri': {0: 'entity1'}} + file = ExternalResourcesManagerContainer(name="file") + data = 
Data(name="species", data=["Homo sapiens", "Mus musculus"]) + er.add_ref(file=file, container=data, key="key1", entity_id="entity_id1", entity_uri="entity1") + + df = er.get_object_type(object_type="Data") + + expected_df_data = { + "file_object_id": {0: file.object_id}, + "objects_idx": {0: 0}, + "object_id": {0: data.object_id}, + "files_idx": {0: 0}, + "object_type": {0: "Data"}, + "relative_path": {0: ""}, + "field": {0: ""}, + "keys_idx": {0: 0}, + "key": {0: "key1"}, + "entities_idx": {0: 0}, + "entity_id": {0: "entity_id1"}, + "entity_uri": {0: "entity1"}, + } expected_df = pd.DataFrame.from_dict(expected_df_data) - expected_df = expected_df.astype({'keys_idx': 'uint32', - 'objects_idx': 'uint32', - 'files_idx': 'uint32', - 'entities_idx': 'uint32'}) + expected_df = expected_df.astype( + {"keys_idx": "uint32", "objects_idx": "uint32", "files_idx": "uint32", "entities_idx": "uint32"} + ) pd.testing.assert_frame_equal(df, expected_df) def test_get_object_type_all_instances(self): er = ExternalResources() - file = ExternalResourcesManagerContainer(name='file') - data = Data(name="species", data=['Homo sapiens', 'Mus musculus']) - er.add_ref(file=file, - container=data, - key='key1', - entity_id='entity_id1', - entity_uri='entity1') - - df = er.get_object_type(object_type='Data', all_instances=True) - - expected_df_data = \ - {'file_object_id': {0: file.object_id}, - 'objects_idx': {0: 0}, - 'object_id': {0: data.object_id}, - 'files_idx': {0: 0}, - 'object_type': {0: 'Data'}, - 'relative_path': {0: ''}, - 'field': {0: ''}, - 'keys_idx': {0: 0}, - 'key': {0: 'key1'}, - 'entities_idx': {0: 0}, - 'entity_id': {0: 'entity_id1'}, - 'entity_uri': {0: 'entity1'}} + file = ExternalResourcesManagerContainer(name="file") + data = Data(name="species", data=["Homo sapiens", "Mus musculus"]) + er.add_ref(file=file, container=data, key="key1", entity_id="entity_id1", entity_uri="entity1") + + df = er.get_object_type(object_type="Data", all_instances=True) + + expected_df_data = { + "file_object_id": {0: file.object_id}, + "objects_idx": {0: 0}, + "object_id": {0: data.object_id}, + "files_idx": {0: 0}, + "object_type": {0: "Data"}, + "relative_path": {0: ""}, + "field": {0: ""}, + "keys_idx": {0: 0}, + "key": {0: "key1"}, + "entities_idx": {0: 0}, + "entity_id": {0: "entity_id1"}, + "entity_uri": {0: "entity1"}, + } expected_df = pd.DataFrame.from_dict(expected_df_data) - expected_df = expected_df.astype({'keys_idx': 'uint32', - 'objects_idx': 'uint32', - 'files_idx': 'uint32', - 'entities_idx': 'uint32'}) + expected_df = expected_df.astype( + {"keys_idx": "uint32", "objects_idx": "uint32", "files_idx": "uint32", "entities_idx": "uint32"} + ) pd.testing.assert_frame_equal(df, expected_df) def test_get_entities(self): er = ExternalResources() - data = Data(name="species", data=['Homo sapiens', 'Mus musculus']) - file = ExternalResourcesManagerContainer(name='file') - er.add_ref(file=file, - container=data, - key='key1', - entity_id='entity_id1', - entity_uri='entity1') - - df = er.get_object_entities(file=file, - container=data) - expected_df_data = \ - {'key_names': {0: 'key1'}, - 'entity_id': {0: 'entity_id1'}, - 'entity_uri': {0: 'entity1'}} + data = Data(name="species", data=["Homo sapiens", "Mus musculus"]) + file = ExternalResourcesManagerContainer(name="file") + er.add_ref(file=file, container=data, key="key1", entity_id="entity_id1", entity_uri="entity1") + + df = er.get_object_entities(file=file, container=data) + expected_df_data = {"key_names": {0: "key1"}, "entity_id": {0: "entity_id1"}, 
"entity_uri": {0: "entity1"}} expected_df = pd.DataFrame.from_dict(expected_df_data) pd.testing.assert_frame_equal(df, expected_df) @@ -336,16 +318,10 @@ def test_get_entities(self): def test_get_entities_file_none_container(self): er = ExternalResources() file = ExternalResourcesManagerContainer() - er.add_ref(container=file, - key='key1', - entity_id='entity_id1', - entity_uri='entity1') + er.add_ref(container=file, key="key1", entity_id="entity_id1", entity_uri="entity1") df = er.get_object_entities(container=file) - expected_df_data = \ - {'key_names': {0: 'key1'}, - 'entity_id': {0: 'entity_id1'}, - 'entity_uri': {0: 'entity1'}} + expected_df_data = {"key_names": {0: "key1"}, "entity_id": {0: "entity_id1"}, "entity_uri": {0: "entity1"}} expected_df = pd.DataFrame.from_dict(expected_df_data) pd.testing.assert_frame_equal(df, expected_df) @@ -353,20 +329,14 @@ def test_get_entities_file_none_container(self): def test_get_entities_file_none_not_container_nested(self): er = ExternalResources() file = ExternalResourcesManagerContainer() - child = Container(name='child') + child = Container(name="child") child.parent = file - er.add_ref(container=child, - key='key1', - entity_id='entity_id1', - entity_uri='entity1') + er.add_ref(container=child, key="key1", entity_id="entity_id1", entity_uri="entity1") df = er.get_object_entities(container=child) - expected_df_data = \ - {'key_names': {0: 'key1'}, - 'entity_id': {0: 'entity_id1'}, - 'entity_uri': {0: 'entity1'}} + expected_df_data = {"key_names": {0: "key1"}, "entity_id": {0: "entity_id1"}, "entity_uri": {0: "entity1"}} expected_df = pd.DataFrame.from_dict(expected_df_data) pd.testing.assert_frame_equal(df, expected_df) @@ -374,171 +344,165 @@ def test_get_entities_file_none_not_container_nested(self): def test_get_entities_file_none_not_container_deep_nested(self): er = ExternalResources() file = ExternalResourcesManagerContainer() - child = Container(name='child') - nested_child = Container(name='nested_child') + child = Container(name="child") + nested_child = Container(name="nested_child") child.parent = file nested_child.parent = child - er.add_ref(container=nested_child, - key='key1', - entity_id='entity_id1', - entity_uri='entity1') + er.add_ref(container=nested_child, key="key1", entity_id="entity_id1", entity_uri="entity1") df = er.get_object_entities(container=nested_child) - expected_df_data = \ - {'key_names': {0: 'key1'}, - 'entity_id': {0: 'entity_id1'}, - 'entity_uri': {0: 'entity1'}} + expected_df_data = {"key_names": {0: "key1"}, "entity_id": {0: "entity_id1"}, "entity_uri": {0: "entity1"}} expected_df = pd.DataFrame.from_dict(expected_df_data) pd.testing.assert_frame_equal(df, expected_df) def test_get_entities_file_none_error(self): er = ExternalResources() - data = Data(name="species", data=['Homo sapiens', 'Mus musculus']) - file = ExternalResourcesManagerContainer(name='file') - er.add_ref(file=file, - container=data, - key='key1', - entity_id='entity_id1', - entity_uri='entity1') + data = Data(name="species", data=["Homo sapiens", "Mus musculus"]) + file = ExternalResourcesManagerContainer(name="file") + er.add_ref(file=file, container=data, key="key1", entity_id="entity_id1", entity_uri="entity1") with self.assertRaises(ValueError): _ = er.get_object_entities(container=data) def test_get_entities_attribute(self): - table = DynamicTable(name='table', description='table') - table.add_column(name='col1', description="column") - table.add_row(id=0, col1='data') - - file = ExternalResourcesManagerContainer(name='file') - 
-        er = ExternalResources()
-        er.add_ref(file=file,
-                   container=table,
-                   attribute='col1',
-                   key='key1',
-                   entity_id='entity_0',
-                   entity_uri='entity_0_uri')
-        df = er.get_object_entities(file=file,
-                                    container=table,
-                                    attribute='col1')
-
-        expected_df_data = \
-            {'key_names': {0: 'key1'},
-             'entity_id': {0: 'entity_0'},
-             'entity_uri': {0: 'entity_0_uri'}}
+        table = DynamicTable(name="table", description="table")
+        table.add_column(name="col1", description="column")
+        table.add_row(id=0, col1="data")
+
+        file = ExternalResourcesManagerContainer(name="file")
+
+        er = ExternalResources()
+        er.add_ref(
+            file=file, container=table, attribute="col1", key="key1", entity_id="entity_0", entity_uri="entity_0_uri"
+        )
+        df = er.get_object_entities(file=file, container=table, attribute="col1")
+
+        expected_df_data = {"key_names": {0: "key1"}, "entity_id": {0: "entity_0"}, "entity_uri": {0: "entity_0_uri"}}
         expected_df = pd.DataFrame.from_dict(expected_df_data)
         pd.testing.assert_frame_equal(df, expected_df)
     def test_to_and_from_norm_tsv(self):
         er = ExternalResources()
-        data = Data(name="species", data=['Homo sapiens', 'Mus musculus'])
-        er.add_ref(file=ExternalResourcesManagerContainer(name='file'),
-                   container=data,
-                   key='key1',
-                   entity_id='entity_id1',
-                   entity_uri='entity1')
-        er.to_norm_tsv(path='./')
+        data = Data(name="species", data=["Homo sapiens", "Mus musculus"])
+        er.add_ref(
+            file=ExternalResourcesManagerContainer(name="file"),
+            container=data,
+            key="key1",
+            entity_id="entity_id1",
+            entity_uri="entity1",
+        )
+        er.to_norm_tsv(path="./")
-        er_read = ExternalResources.from_norm_tsv(path='./')
+        er_read = ExternalResources.from_norm_tsv(path="./")
         ExternalResources.assert_external_resources_equal(er_read, er, check_dtype=False)
         self.remove_er_files()
     def test_to_and_from_norm_tsv_entity_value_error(self):
         er = ExternalResources()
-        data = Data(name="species", data=['Homo sapiens', 'Mus musculus'])
-        er.add_ref(file=ExternalResourcesManagerContainer(name='file'),
-                   container=data,
-                   key='key1',
-                   entity_id='entity_id1',
-                   entity_uri='entity1')
-        er.to_norm_tsv(path='./')
+        data = Data(name="species", data=["Homo sapiens", "Mus musculus"])
+        er.add_ref(
+            file=ExternalResourcesManagerContainer(name="file"),
+            container=data,
+            key="key1",
+            entity_id="entity_id1",
+            entity_uri="entity1",
+        )
+        er.to_norm_tsv(path="./")
         df = er.entities.to_dataframe()
-        df.at[0, ('keys_idx')] = 10  # Change key_ix 0 to 10
-        df.to_csv('./entities.tsv', sep='\t', index=False)
+        df.at[0, "keys_idx"] = 10  # Change key_ix 0 to 10
+        df.to_csv("./entities.tsv", sep="\t", index=False)
         msg = "Key Index out of range in EntityTable. Please check for alterations."
         with self.assertRaisesWith(ValueError, msg):
-            _ = ExternalResources.from_norm_tsv(path='./')
+            _ = ExternalResources.from_norm_tsv(path="./")
         self.remove_er_files()
     def test_to_and_from_norm_tsv_object_value_error(self):
         er = ExternalResources()
-        data = Data(name="species", data=['Homo sapiens', 'Mus musculus'])
-        er.add_ref(file=ExternalResourcesManagerContainer(name='file'),
-                   container=data,
-                   key='key1',
-                   entity_id='entity_id1',
-                   entity_uri='entity1')
-        er.to_norm_tsv(path='./')
+        data = Data(name="species", data=["Homo sapiens", "Mus musculus"])
+        er.add_ref(
+            file=ExternalResourcesManagerContainer(name="file"),
+            container=data,
+            key="key1",
+            entity_id="entity_id1",
+            entity_uri="entity1",
+        )
+        er.to_norm_tsv(path="./")
         df = er.objects.to_dataframe()
-        df.at[0, ('files_idx')] = 10  # Change key_ix 0 to 10
-        df.to_csv('./objects.tsv', sep='\t', index=False)
+        df.at[0, "files_idx"] = 10  # Change files_idx 0 to 10
+        df.to_csv("./objects.tsv", sep="\t", index=False)
         msg = "File_ID Index out of range in ObjectTable. Please check for alterations."
         with self.assertRaisesWith(ValueError, msg):
-            _ = ExternalResources.from_norm_tsv(path='./')
+            _ = ExternalResources.from_norm_tsv(path="./")
         self.remove_er_files()
     def test_to_and_from_norm_tsv_object_keys_object_idx_value_error(self):
         er = ExternalResources()
-        data = Data(name="species", data=['Homo sapiens', 'Mus musculus'])
-        er.add_ref(file=ExternalResourcesManagerContainer(name='file'),
-                   container=data,
-                   key='key1',
-                   entity_id='entity_id1',
-                   entity_uri='entity1')
-        er.to_norm_tsv(path='./')
+        data = Data(name="species", data=["Homo sapiens", "Mus musculus"])
+        er.add_ref(
+            file=ExternalResourcesManagerContainer(name="file"),
+            container=data,
+            key="key1",
+            entity_id="entity_id1",
+            entity_uri="entity1",
+        )
+        er.to_norm_tsv(path="./")
         df = er.object_keys.to_dataframe()
-        df.at[0, ('objects_idx')] = 10  # Change key_ix 0 to 10
-        df.to_csv('./object_keys.tsv', sep='\t', index=False)
+        df.at[0, "objects_idx"] = 10  # Change objects_idx 0 to 10
+        df.to_csv("./object_keys.tsv", sep="\t", index=False)
         msg = "Object Index out of range in ObjectKeyTable. Please check for alterations."
         with self.assertRaisesWith(ValueError, msg):
-            _ = ExternalResources.from_norm_tsv(path='./')
+            _ = ExternalResources.from_norm_tsv(path="./")
         self.remove_er_files()
     def test_to_and_from_norm_tsv_object_keys_key_idx_value_error(self):
         er = ExternalResources()
-        data = Data(name="species", data=['Homo sapiens', 'Mus musculus'])
-        er.add_ref(file=ExternalResourcesManagerContainer(name='file'),
-                   container=data,
-                   key='key1',
-                   entity_id='entity_id1',
-                   entity_uri='entity1')
-        er.to_norm_tsv(path='./')
+        data = Data(name="species", data=["Homo sapiens", "Mus musculus"])
+        er.add_ref(
+            file=ExternalResourcesManagerContainer(name="file"),
+            container=data,
+            key="key1",
+            entity_id="entity_id1",
+            entity_uri="entity1",
+        )
+        er.to_norm_tsv(path="./")
         df = er.object_keys.to_dataframe()
-        df.at[0, ('keys_idx')] = 10  # Change key_ix 0 to 10
-        df.to_csv('./object_keys.tsv', sep='\t', index=False)
+        df.at[0, "keys_idx"] = 10  # Change key_ix 0 to 10
+        df.to_csv("./object_keys.tsv", sep="\t", index=False)
         msg = "Key Index out of range in ObjectKeyTable. Please check for alterations."
with self.assertRaisesWith(ValueError, msg): - _ = ExternalResources.from_norm_tsv(path='./') + _ = ExternalResources.from_norm_tsv(path="./") self.remove_er_files() def test_to_flat_tsv_and_from_flat_tsv(self): # write er to file er = ExternalResources() - data = Data(name="species", data=['Homo sapiens', 'Mus musculus']) - er.add_ref(file=ExternalResourcesManagerContainer(name='file'), - container=data, - key='key1', - entity_id='entity_id1', - entity_uri='entity1') - er.to_flat_tsv(path='./er.tsv') + data = Data(name="species", data=["Homo sapiens", "Mus musculus"]) + er.add_ref( + file=ExternalResourcesManagerContainer(name="file"), + container=data, + key="key1", + entity_id="entity_id1", + entity_uri="entity1", + ) + er.to_flat_tsv(path="./er.tsv") # read er back from file and compare - er_obj = ExternalResources.from_flat_tsv(path='./er.tsv') + er_obj = ExternalResources.from_flat_tsv(path="./er.tsv") # Check that the data is correct ExternalResources.assert_external_resources_equal(er_obj, er, check_dtype=False) self.remove_er_files() @@ -546,8 +510,8 @@ def test_to_flat_tsv_and_from_flat_tsv(self): def test_to_flat_tsv_and_from_flat_tsv_missing_keyidx(self): # write er to file df = self.container.to_dataframe(use_categories=True) - df.at[0, ('keys', 'keys_idx')] = 10 # Change key_ix 0 to 10 - df.to_csv(self.export_filename, sep='\t') + df.at[0, ("keys", "keys_idx")] = 10 # Change key_ix 0 to 10 + df.to_csv(self.export_filename, sep="\t") # read er back from file and compare msg = "Missing keys_idx entries [0, 2, 3, 4, 5, 6, 7, 8, 9]" with self.assertRaisesWith(ValueError, msg): @@ -556,8 +520,8 @@ def test_to_flat_tsv_and_from_flat_tsv_missing_keyidx(self): def test_to_flat_tsv_and_from_flat_tsv_missing_objectidx(self): # write er to file df = self.container.to_dataframe(use_categories=True) - df.at[0, ('objects', 'objects_idx')] = 10 # Change objects_idx 0 to 10 - df.to_csv(self.export_filename, sep='\t') + df.at[0, ("objects", "objects_idx")] = 10 # Change objects_idx 0 to 10 + df.to_csv(self.export_filename, sep="\t") # read er back from file and compare msg = "Missing objects_idx entries [0, 2, 3, 4, 5, 6, 7, 8, 9]" with self.assertRaisesWith(ValueError, msg): @@ -566,8 +530,8 @@ def test_to_flat_tsv_and_from_flat_tsv_missing_objectidx(self): def test_to_flat_tsv_and_from_flat_tsv_missing_entitiesidx(self): # write er to file er_df = self.container.to_dataframe(use_categories=True) - er_df.at[0, ('entities', 'entities_idx')] = 10 # Change entities_idx 0 to 10 - er_df.to_csv(self.export_filename, sep='\t') + er_df.at[0, ("entities", "entities_idx")] = 10 # Change entities_idx 0 to 10 + er_df.to_csv(self.export_filename, sep="\t") # read er back from file and compare msg = "Missing entities_idx entries [0, 2, 3, 4, 5, 6, 7, 8, 9]" with self.assertRaisesWith(ValueError, msg): @@ -575,196 +539,231 @@ def test_to_flat_tsv_and_from_flat_tsv_missing_entitiesidx(self): def test_add_ref_two_keys(self): er = ExternalResources() - ref_container_1 = Container(name='Container_1') - ref_container_2 = Container(name='Container_2') - er.add_ref(file=ExternalResourcesManagerContainer(name='file'), - container=ref_container_1, - key='key1', - entity_id="id11", - entity_uri='url11') - er.add_ref(file=ExternalResourcesManagerContainer(name='file'), - container=ref_container_2, - key='key2', - entity_id="id12", - entity_uri='url21') - - self.assertEqual(er.keys.data, [('key1',), ('key2',)]) - self.assertEqual(er.entities.data, [(0, 'id11', 'url11'), (1, 'id12', 'url21')]) - - 
self.assertEqual(er.objects.data, [(0, ref_container_1.object_id, 'Container', '', ''), - (1, ref_container_2.object_id, 'Container', '', '')]) + ref_container_1 = Container(name="Container_1") + ref_container_2 = Container(name="Container_2") + er.add_ref( + file=ExternalResourcesManagerContainer(name="file"), + container=ref_container_1, + key="key1", + entity_id="id11", + entity_uri="url11", + ) + er.add_ref( + file=ExternalResourcesManagerContainer(name="file"), + container=ref_container_2, + key="key2", + entity_id="id12", + entity_uri="url21", + ) + + self.assertEqual(er.keys.data, [("key1",), ("key2",)]) + self.assertEqual(er.entities.data, [(0, "id11", "url11"), (1, "id12", "url21")]) + + self.assertEqual( + er.objects.data, + [(0, ref_container_1.object_id, "Container", "", ""), (1, ref_container_2.object_id, "Container", "", "")], + ) def test_add_ref_same_key_diff_objfield(self): er = ExternalResources() - ref_container_1 = Container(name='Container_1') - ref_container_2 = Container(name='Container_2') - er.add_ref(file=ExternalResourcesManagerContainer(name='file'), - container=ref_container_1, - key='key1', - entity_id="id11", - entity_uri='url11') - er.add_ref(file=ExternalResourcesManagerContainer(name='file'), - container=ref_container_2, - key='key1', - entity_id="id12", - entity_uri='url21') - - self.assertEqual(er.keys.data, [('key1',), ('key1',)]) - self.assertEqual(er.entities.data, [(0, 'id11', 'url11'), (1, 'id12', 'url21')]) - self.assertEqual(er.objects.data, [(0, ref_container_1.object_id, 'Container', '', ''), - (1, ref_container_2.object_id, 'Container', '', '')]) + ref_container_1 = Container(name="Container_1") + ref_container_2 = Container(name="Container_2") + er.add_ref( + file=ExternalResourcesManagerContainer(name="file"), + container=ref_container_1, + key="key1", + entity_id="id11", + entity_uri="url11", + ) + er.add_ref( + file=ExternalResourcesManagerContainer(name="file"), + container=ref_container_2, + key="key1", + entity_id="id12", + entity_uri="url21", + ) + + self.assertEqual(er.keys.data, [("key1",), ("key1",)]) + self.assertEqual(er.entities.data, [(0, "id11", "url11"), (1, "id12", "url21")]) + self.assertEqual( + er.objects.data, + [(0, ref_container_1.object_id, "Container", "", ""), (1, ref_container_2.object_id, "Container", "", "")], + ) def test_add_ref_same_keyname(self): er = ExternalResources() - ref_container_1 = Container(name='Container_1') - ref_container_2 = Container(name='Container_2') - ref_container_3 = Container(name='Container_2') - er.add_ref(file=ExternalResourcesManagerContainer(name='file'), - container=ref_container_1, - key='key1', - entity_id="id11", - entity_uri='url11') - er.add_ref(file=ExternalResourcesManagerContainer(name='file'), - container=ref_container_2, - key='key1', - entity_id="id12", - entity_uri='url21') - er.add_ref(file=ExternalResourcesManagerContainer(name='file'), - container=ref_container_3, - key='key1', - entity_id="id13", - entity_uri='url31') - self.assertEqual(er.keys.data, [('key1',), ('key1',), ('key1',)]) + ref_container_1 = Container(name="Container_1") + ref_container_2 = Container(name="Container_2") + ref_container_3 = Container(name="Container_2") + er.add_ref( + file=ExternalResourcesManagerContainer(name="file"), + container=ref_container_1, + key="key1", + entity_id="id11", + entity_uri="url11", + ) + er.add_ref( + file=ExternalResourcesManagerContainer(name="file"), + container=ref_container_2, + key="key1", + entity_id="id12", + entity_uri="url21", + ) + er.add_ref( + 
file=ExternalResourcesManagerContainer(name="file"), + container=ref_container_3, + key="key1", + entity_id="id13", + entity_uri="url31", + ) + self.assertEqual(er.keys.data, [("key1",), ("key1",), ("key1",)]) + self.assertEqual(er.entities.data, [(0, "id11", "url11"), (1, "id12", "url21"), (2, "id13", "url31")]) self.assertEqual( - er.entities.data, - [(0, 'id11', 'url11'), - (1, 'id12', 'url21'), - (2, 'id13', 'url31')]) - self.assertEqual(er.objects.data, [(0, ref_container_1.object_id, 'Container', '', ''), - (1, ref_container_2.object_id, 'Container', '', ''), - (2, ref_container_3.object_id, 'Container', '', '')]) + er.objects.data, + [ + (0, ref_container_1.object_id, "Container", "", ""), + (1, ref_container_2.object_id, "Container", "", ""), + (2, ref_container_3.object_id, "Container", "", ""), + ], + ) def test_object_key_unqiueness(self): er = ExternalResources() - data = Data(name='data_name', data=np.array([('Mus musculus', 9, 81.0), ('Homo sapien', 3, 27.0)], - dtype=[('species', 'U14'), ('age', 'i4'), ('weight', 'f4')])) - - er.add_ref(file=ExternalResourcesManagerContainer(name='file'), - container=data, - key='Mus musculus', - entity_id='NCBI:txid10090', - entity_uri='https://www.ncbi.nlm.nih.gov/Taxonomy/Browser/wwwtax.cgi?id=10090') - existing_key = er.get_key('Mus musculus') - er.add_ref(file=ExternalResourcesManagerContainer(name='file'), - container=data, - key=existing_key, - entity_id='entity2', - entity_uri='entity_uri2') + data = Data( + name="data_name", + data=np.array( + [("Mus musculus", 9, 81.0), ("Homo sapien", 3, 27.0)], + dtype=[("species", "U14"), ("age", "i4"), ("weight", "f4")], + ), + ) + + er.add_ref( + file=ExternalResourcesManagerContainer(name="file"), + container=data, + key="Mus musculus", + entity_id="NCBI:txid10090", + entity_uri="https://www.ncbi.nlm.nih.gov/Taxonomy/Browser/wwwtax.cgi?id=10090", + ) + existing_key = er.get_key("Mus musculus") + er.add_ref( + file=ExternalResourcesManagerContainer(name="file"), + container=data, + key=existing_key, + entity_id="entity2", + entity_uri="entity_uri2", + ) self.assertEqual(er.object_keys.data, [(0, 0)]) def test_check_object_field_add(self): er = ExternalResources() - data = Data(name="species", data=['Homo sapiens', 'Mus musculus']) - er._check_object_field(file=ExternalResourcesManagerContainer(name='file'), - container=data, - relative_path='', - field='') + data = Data(name="species", data=["Homo sapiens", "Mus musculus"]) + er._check_object_field( + file=ExternalResourcesManagerContainer(name="file"), container=data, relative_path="", field="" + ) - self.assertEqual(er.objects.data, [(0, data.object_id, 'Data', '', '')]) + self.assertEqual(er.objects.data, [(0, data.object_id, "Data", "", "")]) def test_check_object_field_multi_files(self): er = ExternalResources() - data = Data(name="species", data=['Homo sapiens', 'Mus musculus']) - file = ExternalResourcesManagerContainer(name='file') + data = Data(name="species", data=["Homo sapiens", "Mus musculus"]) + file = ExternalResourcesManagerContainer(name="file") - er._check_object_field(file=file, container=data, relative_path='', field='') + er._check_object_field(file=file, container=data, relative_path="", field="") er._add_file(file.object_id) - data2 = Data(name="species", data=['Homo sapiens', 'Mus musculus']) + data2 = Data(name="species", data=["Homo sapiens", "Mus musculus"]) with self.assertRaises(ValueError): - er._check_object_field(file=file, container=data2, relative_path='', field='') + er._check_object_field(file=file, 
container=data2, relative_path="", field="") def test_check_object_field_multi_error(self): er = ExternalResources() - data = Data(name="species", data=['Homo sapiens', 'Mus musculus']) - er._check_object_field(file=ExternalResourcesManagerContainer(name='file'), - container=data, - relative_path='', - field='') - er._add_object(files_idx=0, container=data, relative_path='', field='') + data = Data(name="species", data=["Homo sapiens", "Mus musculus"]) + er._check_object_field( + file=ExternalResourcesManagerContainer(name="file"), container=data, relative_path="", field="" + ) + er._add_object(files_idx=0, container=data, relative_path="", field="") with self.assertRaises(ValueError): - er._check_object_field(file=ExternalResourcesManagerContainer(name='file'), - container=data, - relative_path='', - field='') + er._check_object_field( + file=ExternalResourcesManagerContainer(name="file"), container=data, relative_path="", field="" + ) def test_check_object_field_not_in_obj_table(self): er = ExternalResources() - data = Data(name="species", data=['Homo sapiens', 'Mus musculus']) + data = Data(name="species", data=["Homo sapiens", "Mus musculus"]) with self.assertRaises(ValueError): - er._check_object_field(file=ExternalResourcesManagerContainer(name='file'), - container=data, - relative_path='', - field='', - create=False) + er._check_object_field( + file=ExternalResourcesManagerContainer(name="file"), + container=data, + relative_path="", + field="", + create=False, + ) def test_add_ref_attribute(self): # Test to make sure the attribute object is being used for the id # for the external reference. - table = DynamicTable(name='table', description='table') - table.add_column(name='col1', description="column") - table.add_row(id=0, col1='data') - - er = ExternalResources() - er.add_ref(file=ExternalResourcesManagerContainer(name='file'), - container=table, - attribute='id', - key='key1', - entity_id='entity_0', - entity_uri='entity_0_uri') + table = DynamicTable(name="table", description="table") + table.add_column(name="col1", description="column") + table.add_row(id=0, col1="data") + + er = ExternalResources() + er.add_ref( + file=ExternalResourcesManagerContainer(name="file"), + container=table, + attribute="id", + key="key1", + entity_id="entity_0", + entity_uri="entity_0_uri", + ) - self.assertEqual(er.keys.data, [('key1',)]) - self.assertEqual(er.entities.data, [(0, 'entity_0', 'entity_0_uri')]) - self.assertEqual(er.objects.data, [(0, table.id.object_id, 'ElementIdentifiers', '', '')]) + self.assertEqual(er.keys.data, [("key1",)]) + self.assertEqual(er.entities.data, [(0, "entity_0", "entity_0_uri")]) + self.assertEqual(er.objects.data, [(0, table.id.object_id, "ElementIdentifiers", "", "")]) def test_add_ref_column_as_attribute(self): # Test to make sure the attribute object is being used for the id # for the external reference. 
- table = DynamicTable(name='table', description='table') - table.add_column(name='col1', description="column") - table.add_row(id=0, col1='data') - - er = ExternalResources() - er.add_ref(file=ExternalResourcesManagerContainer(name='file'), - container=table, - attribute='col1', - key='key1', - entity_id='entity_0', - entity_uri='entity_0_uri') + table = DynamicTable(name="table", description="table") + table.add_column(name="col1", description="column") + table.add_row(id=0, col1="data") + + er = ExternalResources() + er.add_ref( + file=ExternalResourcesManagerContainer(name="file"), + container=table, + attribute="col1", + key="key1", + entity_id="entity_0", + entity_uri="entity_0_uri", + ) - self.assertEqual(er.keys.data, [('key1',)]) - self.assertEqual(er.entities.data, [(0, 'entity_0', 'entity_0_uri')]) - self.assertEqual(er.objects.data, [(0, table['col1'].object_id, 'VectorData', '', '')]) + self.assertEqual(er.keys.data, [("key1",)]) + self.assertEqual(er.entities.data, [(0, "entity_0", "entity_0_uri")]) + self.assertEqual(er.objects.data, [(0, table["col1"].object_id, "VectorData", "", "")]) def test_add_ref_compound_data(self): er = ExternalResources() data = Data( - name='data_name', + name="data_name", data=np.array( - [('Mus musculus', 9, 81.0), ('Homo sapiens', 3, 27.0)], - dtype=[('species', 'U14'), ('age', 'i4'), ('weight', 'f4')])) - er.add_ref(file=ExternalResourcesManagerContainer(name='file'), - container=data, - field='species', - key='Mus musculus', - entity_id='NCBI:txid10090', - entity_uri='entity_0_uri') - - self.assertEqual(er.keys.data, [('Mus musculus',)]) - self.assertEqual(er.entities.data, [(0, 'NCBI:txid10090', 'entity_0_uri')]) - self.assertEqual(er.objects.data, [(0, data.object_id, 'Data', '', 'species')]) + [("Mus musculus", 9, 81.0), ("Homo sapiens", 3, 27.0)], + dtype=[("species", "U14"), ("age", "i4"), ("weight", "f4")], + ), + ) + er.add_ref( + file=ExternalResourcesManagerContainer(name="file"), + container=data, + field="species", + key="Mus musculus", + entity_id="NCBI:txid10090", + entity_uri="entity_0_uri", + ) + + self.assertEqual(er.keys.data, [("Mus musculus",)]) + self.assertEqual(er.entities.data, [(0, "NCBI:txid10090", "entity_0_uri")]) + self.assertEqual(er.objects.data, [(0, data.object_id, "Data", "", "species")]) def test_roundtrip(self): read_container = self.roundtripContainer() @@ -776,183 +775,172 @@ def test_roundtrip_export(self): class TestExternalResourcesNestedAttributes(TestCase): - def setUp(self): - self.attr1 = AttributeSpec(name='attr1', doc='a string attribute', dtype='text') - self.attr2 = AttributeSpec(name='attr2', doc='an integer attribute', dtype='int') - self.attr3 = AttributeSpec(name='attr3', doc='an integer attribute', dtype='int') + self.attr1 = AttributeSpec(name="attr1", doc="a string attribute", dtype="text") + self.attr2 = AttributeSpec(name="attr2", doc="an integer attribute", dtype="int") + self.attr3 = AttributeSpec(name="attr3", doc="an integer attribute", dtype="int") self.bar_spec = GroupSpec( - doc='A test group specification with a data type', - data_type_def='Bar', + doc="A test group specification with a data type", + data_type_def="Bar", datasets=[ DatasetSpec( - doc='a dataset', - dtype='int', - name='data', - attributes=[self.attr2] + doc="a dataset", + dtype="int", + name="data", + attributes=[self.attr2], ) ], - attributes=[self.attr1]) + attributes=[self.attr1], + ) specs = [self.bar_spec] - containers = {'Bar': Bar} + containers = {"Bar": Bar} self.type_map = create_test_type_map(specs, 
containers) self.spec_catalog = self.type_map.namespace_catalog.get_namespace(CORE_NAMESPACE).catalog self.cls = self.type_map.get_dt_container_cls(self.bar_spec.data_type) - self.bar = self.cls(name='bar', data=[1], attr1='attr1', attr2=1) + self.bar = self.cls(name="bar", data=[1], attr1="attr1", attr2=1) obj_mapper_bar = self.type_map.get_map(self.bar) - obj_mapper_bar.map_spec('attr2', spec=self.attr2) + obj_mapper_bar.map_spec("attr2", spec=self.attr2) def test_add_ref_nested(self): - table = DynamicTable(name='table', description='table') - table.add_column(name='col1', description="column") - table.add_row(id=0, col1='data') - - er = ExternalResources() - er.add_ref(file=ExternalResourcesManagerContainer(name='file'), - container=table, - attribute='description', - key='key1', - entity_id='entity_0', - entity_uri='entity_0_uri') - self.assertEqual(er.keys.data, [('key1',)]) - self.assertEqual(er.entities.data, [(0, 'entity_0', 'entity_0_uri')]) - self.assertEqual(er.objects.data, [(0, table.object_id, 'DynamicTable', 'description', '')]) + table = DynamicTable(name="table", description="table") + table.add_column(name="col1", description="column") + table.add_row(id=0, col1="data") + + er = ExternalResources() + er.add_ref( + file=ExternalResourcesManagerContainer(name="file"), + container=table, + attribute="description", + key="key1", + entity_id="entity_0", + entity_uri="entity_0_uri", + ) + self.assertEqual(er.keys.data, [("key1",)]) + self.assertEqual(er.entities.data, [(0, "entity_0", "entity_0_uri")]) + self.assertEqual(er.objects.data, [(0, table.object_id, "DynamicTable", "description", "")]) def test_add_ref_deep_nested(self): er = ExternalResources(type_map=self.type_map) - er.add_ref(file=ExternalResourcesManagerContainer(name='file'), - container=self.bar, - attribute='attr2', - key='key1', - entity_id='entity_0', - entity_uri='entity_0_uri') - self.assertEqual(er.objects.data[0][3], 'data/attr2', '') + er.add_ref( + file=ExternalResourcesManagerContainer(name="file"), + container=self.bar, + attribute="attr2", + key="key1", + entity_id="entity_0", + entity_uri="entity_0_uri", + ) + self.assertEqual(er.objects.data[0][3], "data/attr2", "") class TestExternalResourcesGetKey(TestCase): - def setUp(self): self.er = ExternalResources() def test_get_key_error_more_info(self): - self.er.add_ref(file=ExternalResourcesManagerContainer(name='file'), - container=Container(name='Container'), - key='key1', - entity_id="id11", - entity_uri='url11') - self.er.add_ref(file=ExternalResourcesManagerContainer(name='file'), - container=Container(name='Container'), - key='key1', - entity_id="id12", - entity_uri='url21') + self.er.add_ref( + file=ExternalResourcesManagerContainer(name="file"), + container=Container(name="Container"), + key="key1", + entity_id="id11", + entity_uri="url11", + ) + self.er.add_ref( + file=ExternalResourcesManagerContainer(name="file"), + container=Container(name="Container"), + key="key1", + entity_id="id12", + entity_uri="url21", + ) msg = "There are more than one key with that name. Please search with additional information." 
with self.assertRaisesWith(ValueError, msg): - _ = self.er.get_key(key_name='key1') + _ = self.er.get_key(key_name="key1") def test_get_key(self): - self.er.add_ref(file=ExternalResourcesManagerContainer(name='file'), - container=Container(name='Container'), - key='key1', - entity_id="id11", - entity_uri='url11') + self.er.add_ref( + file=ExternalResourcesManagerContainer(name="file"), + container=Container(name="Container"), + key="key1", + entity_id="id11", + entity_uri="url11", + ) - key = self.er.get_key(key_name='key1') + key = self.er.get_key(key_name="key1") self.assertIsInstance(key, Key) self.assertEqual(key.idx, 0) def test_get_key_bad_arg(self): - self.er.add_ref(file=ExternalResourcesManagerContainer(name='file'), - container=Container(name='Container'), - key='key1', - entity_id="id11", - entity_uri='url11') + self.er.add_ref( + file=ExternalResourcesManagerContainer(name="file"), + container=Container(name="Container"), + key="key1", + entity_id="id11", + entity_uri="url11", + ) with self.assertRaises(ValueError): - self.er.get_key(key_name='key2') + self.er.get_key(key_name="key2") def test_get_key_file_container_provided(self): file = ExternalResourcesManagerContainer() - container1 = Container(name='Container') - self.er.add_ref(file=file, - container=container1, - key='key1', - entity_id="id11", - entity_uri='url11') - self.er.add_ref(file=file, - container=Container(name='Container'), - key='key1', - entity_id="id12", - entity_uri='url21') - - key = self.er.get_key(key_name='key1', container=container1, file=file) + container1 = Container(name="Container") + self.er.add_ref(file=file, container=container1, key="key1", entity_id="id11", entity_uri="url11") + self.er.add_ref( + file=file, container=Container(name="Container"), key="key1", entity_id="id12", entity_uri="url21" + ) + + key = self.er.get_key(key_name="key1", container=container1, file=file) self.assertIsInstance(key, Key) self.assertEqual(key.idx, 0) def test_get_key_no_file_container_provided(self): file = ExternalResourcesManagerContainer() - self.er.add_ref(container=file, key='key1', entity_id="id11", entity_uri='url11') + self.er.add_ref(container=file, key="key1", entity_id="id11", entity_uri="url11") - key = self.er.get_key(key_name='key1', container=file) + key = self.er.get_key(key_name="key1", container=file) self.assertIsInstance(key, Key) self.assertEqual(key.idx, 0) def test_get_key_no_file_nested_container_provided(self): file = ExternalResourcesManagerContainer() - container1 = Container(name='Container') + container1 = Container(name="Container") container1.parent = file - self.er.add_ref(file=file, - container=container1, - key='key1', - entity_id="id11", - entity_uri='url11') + self.er.add_ref(file=file, container=container1, key="key1", entity_id="id11", entity_uri="url11") - key = self.er.get_key(key_name='key1', container=container1) + key = self.er.get_key(key_name="key1", container=container1) self.assertIsInstance(key, Key) self.assertEqual(key.idx, 0) def test_get_key_no_file_deep_nested_container_provided(self): file = ExternalResourcesManagerContainer() - container1 = Container(name='Container1') - container2 = Container(name='Container2') + container1 = Container(name="Container1") + container2 = Container(name="Container2") container1.parent = file container2.parent = container1 - self.er.add_ref(file=file, - container=container2, - key='key1', - entity_id="id11", - entity_uri='url11') + self.er.add_ref(file=file, container=container2, key="key1", entity_id="id11", 
entity_uri="url11") - key = self.er.get_key(key_name='key1', container=container2) + key = self.er.get_key(key_name="key1", container=container2) self.assertIsInstance(key, Key) self.assertEqual(key.idx, 0) def test_get_key_no_file_error(self): file = ExternalResourcesManagerContainer() - container1 = Container(name='Container') - self.er.add_ref(file=file, - container=container1, - key='key1', - entity_id="id11", - entity_uri='url11') + container1 = Container(name="Container") + self.er.add_ref(file=file, container=container1, key="key1", entity_id="id11", entity_uri="url11") with self.assertRaises(ValueError): - _ = self.er.get_key(key_name='key1', container=container1) + _ = self.er.get_key(key_name="key1", container=container1) def test_get_key_no_key_found(self): file = ExternalResourcesManagerContainer() - container1 = Container(name='Container') - self.er.add_ref(file=file, - container=container1, - key='key1', - entity_id="id11", - entity_uri='url11') + container1 = Container(name="Container") + self.er.add_ref(file=file, container=container1, key="key1", entity_id="id11", entity_uri="url11") msg = "No key found with that container." with self.assertRaisesWith(ValueError, msg): - _ = self.er.get_key(key_name='key2', container=container1, file=file) + _ = self.er.get_key(key_name="key2", container=container1, file=file) diff --git a/tests/unit/common/test_sparse.py b/tests/unit/common/test_sparse.py index 7d94231f4..778aebf19 100644 --- a/tests/unit/common/test_sparse.py +++ b/tests/unit/common/test_sparse.py @@ -1,12 +1,11 @@ -from hdmf.common import CSRMatrix -from hdmf.testing import TestCase, H5RoundTripMixin - -import scipy.sparse as sps import numpy as np +import scipy.sparse as sps +from hdmf.common import CSRMatrix +from hdmf.testing import H5RoundTripMixin, TestCase -class TestCSRMatrix(TestCase): +class TestCSRMatrix(TestCase): def test_from_sparse_matrix(self): data = np.array([1, 2, 3, 4, 5, 6]) indices = np.array([0, 2, 2, 0, 1, 2]) @@ -80,36 +79,84 @@ def test_valueerror_non_2D_shape(self): data = np.array([1, 2, 3, 4, 5, 6]) indices = np.array([0, 2, 2, 0, 1, 2]) indptr = np.array([0, 2, 3, 6]) - with self.assertRaisesWith(ValueError, "'shape' argument must specify two and only two dimensions."): + with self.assertRaisesWith( + ValueError, + "'shape' argument must specify two and only two dimensions.", + ): _ = CSRMatrix(data=data, indices=indices, indptr=indptr, shape=(3, 3, 1)) - with self.assertRaisesWith(ValueError, "'shape' argument must specify two and only two dimensions."): - _ = CSRMatrix(data=data, indices=indices, indptr=indptr, shape=(9, )) + with self.assertRaisesWith( + ValueError, + "'shape' argument must specify two and only two dimensions.", + ): + _ = CSRMatrix(data=data, indices=indices, indptr=indptr, shape=(9,)) def test_valueerror_non_1d_indptr_or_indicies(self): data = np.array([1, 2, 3, 4, 5, 6]) indices = np.array([0, 2, 2, 0, 1, 2]) indptr = np.array([0, 2, 3, 6]) - with self.assertRaisesWith(ValueError, "'indices' must be a 1D array of unsigned integers."): - _ = CSRMatrix(data=data, indices=indices.reshape((3, 2)), indptr=indptr, shape=(3, 3)) - with self.assertRaisesWith(ValueError, "'indptr' must be a 1D array of unsigned integers."): - _ = CSRMatrix(data=data, indices=indices, indptr=indptr.reshape((2, 2)), shape=(3, 3)) + with self.assertRaisesWith(ValueError, "'indices' must be a 1D array of unsigned integers."): + _ = CSRMatrix( + data=data, + indices=indices.reshape((3, 2)), + indptr=indptr, + shape=(3, 3), + ) + with 
self.assertRaisesWith(ValueError, "'indptr' must be a 1D array of unsigned integers."): + _ = CSRMatrix( + data=data, + indices=indices, + indptr=indptr.reshape((2, 2)), + shape=(3, 3), + ) def test_valueerror_non_int_indptr_or_indicies(self): data = np.array([1, 2, 3, 4, 5, 6]) indices = np.array([0, 2, 2, 0, 1, 2]) indptr = np.array([0, 2, 3, 6]) # test indices numpy array of floats - with self.assertRaisesWith(ValueError, "Cannot convert 'indices' to an array of unsigned integers."): - _ = CSRMatrix(data=data, indices=indices.astype(float), indptr=indptr, shape=(3, 3)) + with self.assertRaisesWith( + ValueError, + "Cannot convert 'indices' to an array of unsigned integers.", + ): + _ = CSRMatrix( + data=data, + indices=indices.astype(float), + indptr=indptr, + shape=(3, 3), + ) # test indptr numpy array of floats - with self.assertRaisesWith(ValueError, "Cannot convert 'indptr' to an array of unsigned integers."): - _ = CSRMatrix(data=data, indices=indices, indptr=indptr.astype(float), shape=(3, 3)) + with self.assertRaisesWith( + ValueError, + "Cannot convert 'indptr' to an array of unsigned integers.", + ): + _ = CSRMatrix( + data=data, + indices=indices, + indptr=indptr.astype(float), + shape=(3, 3), + ) # test indices list of floats - with self.assertRaisesWith(ValueError, "Cannot convert 'indices' to an array of unsigned integers."): - _ = CSRMatrix(data=data, indices=indices.astype(float).tolist(), indptr=indptr, shape=(3, 3)) + with self.assertRaisesWith( + ValueError, + "Cannot convert 'indices' to an array of unsigned integers.", + ): + _ = CSRMatrix( + data=data, + indices=indices.astype(float).tolist(), + indptr=indptr, + shape=(3, 3), + ) # test indptr list of floats - with self.assertRaisesWith(ValueError, "Cannot convert 'indptr' to an array of unsigned integers."): - _ = CSRMatrix(data=data, indices=indices, indptr=indptr.astype(float).tolist(), shape=(3, 3)) + with self.assertRaisesWith( + ValueError, + "Cannot convert 'indptr' to an array of unsigned integers.", + ): + _ = CSRMatrix( + data=data, + indices=indices, + indptr=indptr.astype(float).tolist(), + shape=(3, 3), + ) def test_constructor_indices_missing(self): data = np.array([1, 2, 3, 4, 5, 6]) @@ -139,7 +186,7 @@ def test_constructor_bad_shape(self): data = np.array([1, 2, 3, 4, 5, 6]) indices = np.array([0, 2, 2, 0, 1, 2]) indptr = np.array([0, 2, 3, 6]) - shape = (3, ) + shape = (3,) msg = "'shape' argument must specify two and only two dimensions." 
with self.assertRaisesWith(ValueError, msg): CSRMatrix(data=data, indices=indices, indptr=indptr, shape=shape) @@ -155,7 +202,6 @@ def test_array_bad_dim(self): class TestCSRMatrixRoundTrip(H5RoundTripMixin, TestCase): - def setUpContainer(self): data = np.array([1, 2, 3, 4, 5, 6]) indices = np.array([0, 2, 2, 0, 1, 2]) diff --git a/tests/unit/common/test_table.py b/tests/unit/common/test_table.py index ad57b56a1..ed111e58e 100644 --- a/tests/unit/common/test_table.py +++ b/tests/unit/common/test_table.py @@ -1,169 +1,217 @@ +import os +import unittest from collections import OrderedDict + import h5py import numpy as np -import os import pandas as pd -import unittest from hdmf import Container -from hdmf.backends.hdf5 import H5DataIO, HDF5IO +from hdmf.backends.hdf5 import HDF5IO, H5DataIO from hdmf.backends.hdf5.h5tools import H5_TEXT, H5PY_3 -from hdmf.common import (DynamicTable, VectorData, VectorIndex, ElementIdentifiers, EnumData, - DynamicTableRegion, get_manager, SimpleMultiContainer) -from hdmf.testing import TestCase, H5RoundTripMixin, remove_test_file +from hdmf.common import ( + DynamicTable, + DynamicTableRegion, + ElementIdentifiers, + EnumData, + SimpleMultiContainer, + VectorData, + VectorIndex, + get_manager, +) +from hdmf.testing import H5RoundTripMixin, TestCase, remove_test_file from hdmf.utils import StrDataset -from tests.unit.helpers.utils import get_temp_filepath +from ..helpers.utils import get_temp_filepath class TestDynamicTable(TestCase): - def setUp(self): self.spec = [ - {'name': 'foo', 'description': 'foo column'}, - {'name': 'bar', 'description': 'bar column'}, - {'name': 'baz', 'description': 'baz column'}, + {"name": "foo", "description": "foo column"}, + {"name": "bar", "description": "bar column"}, + {"name": "baz", "description": "baz column"}, ] self.data = [ [1, 2, 3, 4, 5], [10.0, 20.0, 30.0, 40.0, 50.0], - ['cat', 'dog', 'bird', 'fish', 'lizard'] + ["cat", "dog", "bird", "fish", "lizard"], ] def with_table_columns(self): cols = [VectorData(**d) for d in self.spec] - table = DynamicTable(name="with_table_columns", description='a test table', columns=cols) + table = DynamicTable(name="with_table_columns", description="a test table", columns=cols) return table def with_columns_and_data(self): columns = [ - VectorData(name=s['name'], description=s['description'], data=d) - for s, d in zip(self.spec, self.data) + VectorData(name=s["name"], description=s["description"], data=d) for s, d in zip(self.spec, self.data) ] - return DynamicTable(name="with_columns_and_data", description='a test table', columns=columns) + return DynamicTable( + name="with_columns_and_data", + description="a test table", + columns=columns, + ) def with_spec(self): - table = DynamicTable(name="with_spec", description='a test table', columns=self.spec) + table = DynamicTable(name="with_spec", description="a test table", columns=self.spec) return table def check_empty_table(self, table): self.assertIsInstance(table.columns, tuple) self.assertIsInstance(table.columns[0], VectorData) self.assertEqual(len(table.columns), 3) - self.assertTupleEqual(table.colnames, ('foo', 'bar', 'baz')) + self.assertTupleEqual(table.colnames, ("foo", "bar", "baz")) def test_constructor_table_columns(self): table = self.with_table_columns() - self.assertEqual(table.name, 'with_table_columns') + self.assertEqual(table.name, "with_table_columns") self.check_empty_table(table) def test_constructor_spec(self): table = self.with_spec() - self.assertEqual(table.name, 'with_spec') + self.assertEqual(table.name, 
"with_spec") self.check_empty_table(table) def check_table(self, table): self.assertEqual(len(table), 5) self.assertEqual(table.columns[0].data, [1, 2, 3, 4, 5]) self.assertEqual(table.columns[1].data, [10.0, 20.0, 30.0, 40.0, 50.0]) - self.assertEqual(table.columns[2].data, ['cat', 'dog', 'bird', 'fish', 'lizard']) + self.assertEqual(table.columns[2].data, ["cat", "dog", "bird", "fish", "lizard"]) self.assertEqual(table.id.data, [0, 1, 2, 3, 4]) - self.assertTrue(hasattr(table, 'baz')) + self.assertTrue(hasattr(table, "baz")) def test_constructor_ids_default(self): - columns = [VectorData(name=s['name'], description=s['description'], data=d) - for s, d in zip(self.spec, self.data)] - table = DynamicTable(name="with_spec", description='a test table', columns=columns) + columns = [ + VectorData(name=s["name"], description=s["description"], data=d) for s, d in zip(self.spec, self.data) + ] + table = DynamicTable(name="with_spec", description="a test table", columns=columns) self.check_table(table) def test_constructor_ids(self): - columns = [VectorData(name=s['name'], description=s['description'], data=d) - for s, d in zip(self.spec, self.data)] - table = DynamicTable(name="with_columns", description='a test table', id=[0, 1, 2, 3, 4], columns=columns) + columns = [ + VectorData(name=s["name"], description=s["description"], data=d) for s, d in zip(self.spec, self.data) + ] + table = DynamicTable( + name="with_columns", + description="a test table", + id=[0, 1, 2, 3, 4], + columns=columns, + ) self.check_table(table) def test_constructor_ElementIdentifier_ids(self): - columns = [VectorData(name=s['name'], description=s['description'], data=d) - for s, d in zip(self.spec, self.data)] - ids = ElementIdentifiers(name='ids', data=[0, 1, 2, 3, 4]) - table = DynamicTable(name="with_columns", description='a test table', id=ids, columns=columns) + columns = [ + VectorData(name=s["name"], description=s["description"], data=d) for s, d in zip(self.spec, self.data) + ] + ids = ElementIdentifiers(name="ids", data=[0, 1, 2, 3, 4]) + table = DynamicTable( + name="with_columns", + description="a test table", + id=ids, + columns=columns, + ) self.check_table(table) def test_constructor_ids_bad_ids(self): - columns = [VectorData(name=s['name'], description=s['description'], data=d) - for s, d in zip(self.spec, self.data)] + columns = [ + VectorData(name=s["name"], description=s["description"], data=d) for s, d in zip(self.spec, self.data) + ] msg = "must provide same number of ids as length of columns" with self.assertRaisesWith(ValueError, msg): - DynamicTable(name="with_columns", description='a test table', id=[0, 1], columns=columns) + DynamicTable( + name="with_columns", + description="a test table", + id=[0, 1], + columns=columns, + ) def test_constructor_bad_columns(self): - columns = ['bad_column'] + columns = ["bad_column"] msg = "'columns' must be a list of dict, VectorData, DynamicTableRegion, or VectorIndex" with self.assertRaisesWith(ValueError, msg): - DynamicTable(name="with_columns", description='a test table', columns=columns) + DynamicTable(name="with_columns", description="a test table", columns=columns) def test_constructor_unequal_length_columns(self): - columns = [VectorData(name='col1', description='desc', data=[1, 2, 3]), - VectorData(name='col2', description='desc', data=[1, 2])] + columns = [ + VectorData(name="col1", description="desc", data=[1, 2, 3]), + VectorData(name="col2", description="desc", data=[1, 2]), + ] msg = "columns must be the same length" with 
self.assertRaisesWith(ValueError, msg): - DynamicTable(name="with_columns", description='a test table', columns=columns) + DynamicTable(name="with_columns", description="a test table", columns=columns) def test_constructor_colnames(self): """Test that passing colnames correctly sets the order of the columns.""" cols = [VectorData(**d) for d in self.spec] - table = DynamicTable(name="with_columns", description='a test table', - columns=cols, colnames=['baz', 'bar', 'foo']) + table = DynamicTable( + name="with_columns", + description="a test table", + columns=cols, + colnames=["baz", "bar", "foo"], + ) self.assertTupleEqual(table.columns, tuple(cols[::-1])) def test_constructor_colnames_no_columns(self): """Test that passing colnames without columns raises an error.""" msg = "Must supply 'columns' if specifying 'colnames'" with self.assertRaisesWith(ValueError, msg): - DynamicTable(name="with_columns", description='a test table', colnames=['baz', 'bar', 'foo']) + DynamicTable( + name="with_columns", + description="a test table", + colnames=["baz", "bar", "foo"], + ) def test_constructor_colnames_vectorindex(self): """Test that passing colnames with a VectorIndex column puts the index in the right location in columns.""" cols = [VectorData(**d) for d in self.spec] - ind = VectorIndex(name='foo_index', data=list(), target=cols[0]) + ind = VectorIndex(name="foo_index", data=list(), target=cols[0]) cols.append(ind) - table = DynamicTable(name="with_columns", description='a test table', columns=cols, - colnames=['baz', 'bar', 'foo']) + table = DynamicTable( + name="with_columns", + description="a test table", + columns=cols, + colnames=["baz", "bar", "foo"], + ) self.assertTupleEqual(table.columns, (cols[2], cols[1], ind, cols[0])) def test_constructor_colnames_vectorindex_rev(self): """Test that passing colnames with a VectorIndex column puts the index in the right location in columns.""" cols = [VectorData(**d) for d in self.spec] - ind = VectorIndex(name='foo_index', data=list(), target=cols[0]) + ind = VectorIndex(name="foo_index", data=list(), target=cols[0]) cols.insert(0, ind) # put index before its target - table = DynamicTable(name="with_columns", description='a test table', columns=cols, - colnames=['baz', 'bar', 'foo']) + table = DynamicTable( + name="with_columns", + description="a test table", + columns=cols, + colnames=["baz", "bar", "foo"], + ) self.assertTupleEqual(table.columns, (cols[3], cols[2], ind, cols[1])) def test_constructor_dup_index(self): """Test that passing two indices for the same column raises an error.""" cols = [VectorData(**d) for d in self.spec] - cols.append(VectorIndex(name='foo_index', data=list(), target=cols[0])) - cols.append(VectorIndex(name='foo_index2', data=list(), target=cols[0])) + cols.append(VectorIndex(name="foo_index", data=list(), target=cols[0])) + cols.append(VectorIndex(name="foo_index2", data=list(), target=cols[0])) msg = "'columns' contains index columns with the same target: ['foo', 'foo']" with self.assertRaisesWith(ValueError, msg): - DynamicTable(name="with_columns", description='a test table', columns=cols) + DynamicTable(name="with_columns", description="a test table", columns=cols) def test_constructor_index_missing_target(self): """Test that passing an index without its target raises an error.""" cols = [VectorData(**d) for d in self.spec] missing_col = cols.pop(2) - cols.append(VectorIndex(name='foo_index', data=list(), target=missing_col)) + cols.append(VectorIndex(name="foo_index", data=list(), target=missing_col)) msg = "Found 
VectorIndex 'foo_index' but not its target 'baz'" with self.assertRaisesWith(ValueError, msg): - DynamicTable(name="with_columns", description='a test table', columns=cols) + DynamicTable(name="with_columns", description="a test table", columns=cols) def add_rows(self, table): - table.add_row({'foo': 1, 'bar': 10.0, 'baz': 'cat'}) - table.add_row({'foo': 2, 'bar': 20.0, 'baz': 'dog'}) - table.add_row({'foo': 3, 'bar': 30.0, 'baz': 'bird'}) - table.add_row({'foo': 4, 'bar': 40.0, 'baz': 'fish'}) - table.add_row({'foo': 5, 'bar': 50.0, 'baz': 'lizard'}) + table.add_row({"foo": 1, "bar": 10.0, "baz": "cat"}) + table.add_row({"foo": 2, "bar": 20.0, "baz": "dog"}) + table.add_row({"foo": 3, "bar": 30.0, "baz": "bird"}) + table.add_row({"foo": 4, "bar": 40.0, "baz": "fish"}) + table.add_row({"foo": 5, "bar": 50.0, "baz": "lizard"}) def test_add_row(self): table = self.with_spec() @@ -173,18 +221,18 @@ def test_add_row(self): def test_get(self): table = self.with_spec() self.add_rows(table) - self.assertIsInstance(table.get('foo'), VectorData) - self.assertEqual(table.get('foo'), table['foo']) + self.assertIsInstance(table.get("foo"), VectorData) + self.assertEqual(table.get("foo"), table["foo"]) def test_get_not_found(self): table = self.with_spec() self.add_rows(table) - self.assertIsNone(table.get('qux')) + self.assertIsNone(table.get("qux")) def test_get_not_found_default(self): table = self.with_spec() self.add_rows(table) - self.assertEqual(table.get('qux', 1), 1) + self.assertEqual(table.get("qux", 1), 1) def test_get_item(self): table = self.with_spec() @@ -193,41 +241,35 @@ def test_get_item(self): def test_add_column(self): table = self.with_spec() - table.add_column(name='qux', description='qux column') - self.assertTupleEqual(table.colnames, ('foo', 'bar', 'baz', 'qux')) - self.assertTrue(hasattr(table, 'qux')) + table.add_column(name="qux", description="qux column") + self.assertTupleEqual(table.colnames, ("foo", "bar", "baz", "qux")) + self.assertTrue(hasattr(table, "qux")) def test_add_column_twice(self): table = self.with_spec() - table.add_column(name='qux', description='qux column') + table.add_column(name="qux", description="qux column") msg = "column 'qux' already exists in DynamicTable 'with_spec'" with self.assertRaisesWith(ValueError, msg): - table.add_column(name='qux', description='qux column') + table.add_column(name="qux", description="qux column") def test_add_column_vectorindex(self): table = self.with_spec() - table.add_column(name='qux', description='qux column') - ind = VectorIndex(name='quux', data=list(), target=table['qux']) + table.add_column(name="qux", description="qux column") + ind = VectorIndex(name="quux", data=list(), target=table["qux"]) - msg = ("Passing a VectorIndex in for index may lead to unexpected behavior. This functionality will be " - "deprecated in a future version of HDMF.") + msg = ( + "Passing a VectorIndex in for index may lead to unexpected behavior. This" + " functionality will be deprecated in a future version of HDMF." 
+ ) with self.assertWarnsWith(FutureWarning, msg): - table.add_column(name='bad', description='bad column', index=ind) + table.add_column(name="bad", description="bad column", index=ind) def test_add_column_multi_index(self): table = self.with_spec() - table.add_column(name='qux', description='qux column', index=2) - table.add_row(foo=5, bar=50.0, baz='lizard', - qux=[ - [1, 2, 3], - [1, 2, 3, 4] - ]) - table.add_row(foo=5, bar=50.0, baz='lizard', - qux=[ - [1, 2] - ] - ) + table.add_column(name="qux", description="qux column", index=2) + table.add_row(foo=5, bar=50.0, baz="lizard", qux=[[1, 2, 3], [1, 2, 3, 4]]) + table.add_row(foo=5, bar=50.0, baz="lizard", qux=[[1, 2]]) def test_add_column_auto_index_int(self): """ @@ -235,19 +277,21 @@ def test_add_column_auto_index_int(self): with index=1 as parameter """ table = self.with_spec() - table.add_row(foo=5, bar=50.0, baz='lizard') - table.add_row(foo=5, bar=50.0, baz='lizard') - expected = [[1, 2, 3], - [1, 2, 3, 4]] - table.add_column(name='qux', - description='qux column', - data=expected, - index=1) - self.assertListEqual(table['qux'][:], expected) + table.add_row(foo=5, bar=50.0, baz="lizard") + table.add_row(foo=5, bar=50.0, baz="lizard") + expected = [[1, 2, 3], [1, 2, 3, 4]] + table.add_column(name="qux", description="qux column", data=expected, index=1) + self.assertListEqual(table["qux"][:], expected) self.assertListEqual(table.qux_index.data, [3, 7]) # Add more rows after we created the column - table.add_row(foo=5, bar=50.0, baz='lizard', qux=[10, 11, 12]) - self.assertListEqual(table['qux'][:], expected + [[10, 11, 12], ]) + table.add_row(foo=5, bar=50.0, baz="lizard", qux=[10, 11, 12]) + self.assertListEqual( + table["qux"][:], + expected + + [ + [10, 11, 12], + ], + ) self.assertListEqual(table.qux_index.data, [3, 7, 10]) def test_add_column_auto_index_bool(self): @@ -256,19 +300,21 @@ def test_add_column_auto_index_bool(self): with index=True as parameter """ table = self.with_spec() - table.add_row(foo=5, bar=50.0, baz='lizard') - table.add_row(foo=5, bar=50.0, baz='lizard') - expected = [[1, 2, 3], - [1, 2, 3, 4]] - table.add_column(name='qux', - description='qux column', - data=expected, - index=True) - self.assertListEqual(table['qux'][:], expected) + table.add_row(foo=5, bar=50.0, baz="lizard") + table.add_row(foo=5, bar=50.0, baz="lizard") + expected = [[1, 2, 3], [1, 2, 3, 4]] + table.add_column(name="qux", description="qux column", data=expected, index=True) + self.assertListEqual(table["qux"][:], expected) self.assertListEqual(table.qux_index.data, [3, 7]) # Add more rows after we created the column - table.add_row(foo=5, bar=50.0, baz='lizard', qux=[10, 11, 12]) - self.assertListEqual(table['qux'][:], expected + [[10, 11, 12], ]) + table.add_row(foo=5, bar=50.0, baz="lizard", qux=[10, 11, 12]) + self.assertListEqual( + table["qux"][:], + expected + + [ + [10, 11, 12], + ], + ) self.assertListEqual(table.qux_index.data, [3, 7, 10]) def test_add_column_auto_multi_index_int(self): @@ -277,20 +323,31 @@ def test_add_column_auto_multi_index_int(self): two VectorIndex for the column so we set index=2 as parameter """ table = self.with_spec() - table.add_row(foo=5, bar=50.0, baz='lizard') - table.add_row(foo=5, bar=50.0, baz='lizard') - expected = [[[1, 2, 3], [1]], - [[1, 2, 3, 4], [1, 2]]] - table.add_column(name='qux', - description='qux column', - data=expected, - index=2) - self.assertListEqual(table['qux'][:], expected) + table.add_row(foo=5, bar=50.0, baz="lizard") + table.add_row(foo=5, bar=50.0, baz="lizard") + 
expected = [[[1, 2, 3], [1]], [[1, 2, 3, 4], [1, 2]]] + table.add_column(name="qux", description="qux column", data=expected, index=2) + self.assertListEqual(table["qux"][:], expected) self.assertListEqual(table.qux_index_index.data, [2, 4]) self.assertListEqual(table.qux_index.data, [3, 4, 8, 10]) # Add more rows after we created the column - table.add_row(foo=5, bar=50.0, baz='lizard', qux=[[10, 11, 12], ]) - self.assertListEqual(table['qux'][:], expected + [[[10, 11, 12], ]]) + table.add_row( + foo=5, + bar=50.0, + baz="lizard", + qux=[ + [10, 11, 12], + ], + ) + self.assertListEqual( + table["qux"][:], + expected + + [ + [ + [10, 11, 12], + ] + ], + ) self.assertListEqual(table.qux_index_index.data, [2, 4, 5]) self.assertListEqual(table.qux_index.data, [3, 4, 8, 10, 13]) @@ -300,29 +357,27 @@ def test_add_column_auto_multi_index_int_bad_index_levels(self): a two-level index, but we ask for either too many or too few index levels. """ table = self.with_spec() - table.add_row(foo=5, bar=50.0, baz='lizard') - table.add_row(foo=5, bar=50.0, baz='lizard') - expected = [[[1, 2, 3], [1]], - [[1, 2, 3, 4], [1, 2]]] + table.add_row(foo=5, bar=50.0, baz="lizard") + table.add_row(foo=5, bar=50.0, baz="lizard") + expected = [[[1, 2, 3], [1]], [[1, 2, 3, 4], [1, 2]]] msg = "Cannot automatically construct VectorIndex for nested array. Invalid data array element found." with self.assertRaisesWith(ValueError, msg): - table.add_column(name='qux', - description='qux column', - data=expected, - index=3) # Too many index levels given + table.add_column( + name="qux", description="qux column", data=expected, index=3 + ) # Too many index levels given # Asking for too few indexes will work here but should then later fail on write - msg = ("Cannot automatically construct VectorIndex for nested array. " - "Column data contains arrays as cell values. Please check the 'data' and 'index' parameters.") + msg = ( + "Cannot automatically construct VectorIndex for nested array. Column data" + " contains arrays as cell values. Please check the 'data' and 'index'" + " parameters." + ) with self.assertRaisesWith(ValueError, msg + " 'index=1' may be too small for the given data."): - table.add_column(name='qux', - description='qux column', - data=expected, - index=1) - with self.assertRaisesWith(ValueError, msg + " 'index=True' may be too small for the given data."): - table.add_column(name='qux', - description='qux column', - data=expected, - index=True) + table.add_column(name="qux", description="qux column", data=expected, index=1) + with self.assertRaisesWith( + ValueError, + msg + " 'index=True' may be too small for the given data.", + ): + table.add_column(name="qux", description="qux column", data=expected, index=True) def test_add_column_auto_multi_index_int_with_empty_slots(self): """ @@ -331,105 +386,64 @@ def test_add_column_auto_multi_index_int_with_empty_slots(self): multi-indexed column. 
""" table = self.with_spec() - table.add_row(foo=5, bar=50.0, baz='lizard') - table.add_row(foo=5, bar=50.0, baz='lizard') - expected = [[[], []], - [[], []]] - table.add_column(name='qux', - description='qux column', - data=expected, - index=2) - self.assertListEqual(table['qux'][:], expected) + table.add_row(foo=5, bar=50.0, baz="lizard") + table.add_row(foo=5, bar=50.0, baz="lizard") + expected = [[[], []], [[], []]] + table.add_column(name="qux", description="qux column", data=expected, index=2) + self.assertListEqual(table["qux"][:], expected) self.assertListEqual(table.qux_index_index.data, [2, 4]) self.assertListEqual(table.qux_index.data, [0, 0, 0, 0]) # Add more rows after we created the column - table.add_row(foo=5, bar=50.0, baz='lizard', qux=[[10, 11, 12], ]) - self.assertListEqual(table['qux'][:], expected + [[[10, 11, 12], ]]) + table.add_row( + foo=5, + bar=50.0, + baz="lizard", + qux=[ + [10, 11, 12], + ], + ) + self.assertListEqual( + table["qux"][:], + expected + + [ + [ + [10, 11, 12], + ] + ], + ) self.assertListEqual(table.qux_index_index.data, [2, 4, 5]) self.assertListEqual(table.qux_index.data, [0, 0, 0, 0, 3]) def test_auto_multi_index_required(self): - class TestTable(DynamicTable): - __columns__ = (dict(name='qux', description='qux column', index=3, required=True),) + __columns__ = (dict(name="qux", description="qux column", index=3, required=True),) - table = TestTable(name='table_name', description='table_description') + table = TestTable(name="table_name", description="table_description") self.assertIsInstance(table.qux, VectorData) # check that the attribute is set self.assertIsInstance(table.qux_index, VectorIndex) # check that the attribute is set self.assertIsInstance(table.qux_index_index, VectorIndex) # check that the attribute is set self.assertIsInstance(table.qux_index_index_index, VectorIndex) # check that the attribute is set - table.add_row( - qux=[ - [ - [1, 2, 3], - [1, 2, 3, 4] - ] - ] - ) - table.add_row( - qux=[ - [ - [1, 2] - ] - ] - ) + table.add_row(qux=[[[1, 2, 3], [1, 2, 3, 4]]]) + table.add_row(qux=[[[1, 2]]]) - expected = [ - [ - [ - [1, 2, 3], - [1, 2, 3, 4] - ] - ], - [ - [ - [1, 2] - ] - ] - ] - self.assertListEqual(table['qux'][:], expected) + expected = [[[[1, 2, 3], [1, 2, 3, 4]]], [[[1, 2]]]] + self.assertListEqual(table["qux"][:], expected) self.assertEqual(table.qux_index_index_index.data, [1, 2]) def test_auto_multi_index(self): - class TestTable(DynamicTable): - __columns__ = (dict(name='qux', description='qux column', index=3),) # this is optional + __columns__ = (dict(name="qux", description="qux column", index=3),) # this is optional - table = TestTable(name='table_name', description='table_description') + table = TestTable(name="table_name", description="table_description") self.assertIsNone(table.qux) # these are reserved as attributes but not yet initialized self.assertIsNone(table.qux_index) self.assertIsNone(table.qux_index_index) self.assertIsNone(table.qux_index_index_index) - table.add_row( - qux=[ - [ - [1, 2, 3], - [1, 2, 3, 4] - ] - ] - ) - table.add_row( - qux=[ - [ - [1, 2] - ] - ] - ) + table.add_row(qux=[[[1, 2, 3], [1, 2, 3, 4]]]) + table.add_row(qux=[[[1, 2]]]) - expected = [ - [ - [ - [1, 2, 3], - [1, 2, 3, 4] - ] - ], - [ - [ - [1, 2] - ] - ] - ] - self.assertListEqual(table['qux'][:], expected) + expected = [[[[1, 2, 3], [1, 2, 3, 4]]], [[[1, 2]]]] + self.assertListEqual(table["qux"][:], expected) self.assertEqual(table.qux_index_index_index.data, [1, 2]) def test_getitem_row_num(self): @@ -437,7 +451,7 
@@ def test_getitem_row_num(self): self.add_rows(table) row = table[2] self.assertTupleEqual(row.shape, (1, 3)) - self.assertTupleEqual(tuple(row.iloc[0]), (3, 30.0, 'bird')) + self.assertTupleEqual(tuple(row.iloc[0]), (3, 30.0, "bird")) def test_getitem_row_slice(self): table = self.with_spec() @@ -445,7 +459,7 @@ def test_getitem_row_slice(self): rows = table[1:3] self.assertIsInstance(rows, pd.DataFrame) self.assertTupleEqual(rows.shape, (2, 3)) - self.assertTupleEqual(tuple(rows.iloc[1]), (3, 30.0, 'bird')) + self.assertTupleEqual(tuple(rows.iloc[1]), (3, 30.0, "bird")) def test_getitem_row_slice_with_step(self): table = self.with_spec() @@ -455,7 +469,7 @@ def test_getitem_row_slice_with_step(self): self.assertTupleEqual(rows.shape, (3, 3)) self.assertEqual(rows.iloc[2][0], 5) self.assertEqual(rows.iloc[2][1], 50.0) - self.assertEqual(rows.iloc[2][2], 'lizard') + self.assertEqual(rows.iloc[2][2], "lizard") def test_getitem_invalid_keytype(self): table = self.with_spec() @@ -466,7 +480,7 @@ def test_getitem_invalid_keytype(self): def test_getitem_col_select_and_row_slice(self): table = self.with_spec() self.add_rows(table) - col = table[1:3, 'bar'] + col = table[1:3, "bar"] self.assertEqual(len(col), 2) self.assertEqual(col[0], 20.0) self.assertEqual(col[1], 30.0) @@ -474,7 +488,7 @@ def test_getitem_col_select_and_row_slice(self): def test_getitem_column(self): table = self.with_spec() self.add_rows(table) - col = table['bar'] + col = table["bar"] self.assertEqual(col[0], 10.0) self.assertEqual(col[1], 20.0) self.assertEqual(col[2], 30.0) @@ -486,21 +500,21 @@ def test_getitem_list_idx(self): self.add_rows(table) row = table[[0, 2, 4]] self.assertEqual(len(row), 3) - self.assertTupleEqual(tuple(row.iloc[0]), (1, 10.0, 'cat')) - self.assertTupleEqual(tuple(row.iloc[1]), (3, 30.0, 'bird')) - self.assertTupleEqual(tuple(row.iloc[2]), (5, 50.0, 'lizard')) + self.assertTupleEqual(tuple(row.iloc[0]), (1, 10.0, "cat")) + self.assertTupleEqual(tuple(row.iloc[1]), (3, 30.0, "bird")) + self.assertTupleEqual(tuple(row.iloc[2]), (5, 50.0, "lizard")) def test_getitem_point_idx_colname(self): table = self.with_spec() self.add_rows(table) - val = table[2, 'bar'] + val = table[2, "bar"] self.assertEqual(val, 30.0) def test_getitem_point_idx(self): table = self.with_spec() self.add_rows(table) row = table[2] - self.assertTupleEqual(tuple(row.iloc[0]), (3, 30.0, 'bird')) + self.assertTupleEqual(tuple(row.iloc[0]), (3, 30.0, "bird")) def test_getitem_point_idx_colidx(self): table = self.with_spec() @@ -509,12 +523,12 @@ def test_getitem_point_idx_colidx(self): self.assertEqual(val, 30.0) def test_pandas_roundtrip(self): - df = pd.DataFrame({ - 'a': [1, 2, 3, 4], - 'b': ['a', 'b', 'c', '4'] - }, index=pd.Index(name='an_index', data=[2, 4, 6, 8])) + df = pd.DataFrame( + {"a": [1, 2, 3, 4], "b": ["a", "b", "c", "4"]}, + index=pd.Index(name="an_index", data=[2, 4, 6, 8]), + ) - table = DynamicTable.from_dataframe(df, 'foo') + table = DynamicTable.from_dataframe(df, "foo") obtained = table.to_dataframe() self.assertTrue(df.equals(obtained)) @@ -522,45 +536,44 @@ def test_to_dataframe(self): table = self.with_columns_and_data() data = OrderedDict() for name in table.colnames: - if name == 'foo': + if name == "foo": data[name] = [1, 2, 3, 4, 5] - elif name == 'bar': + elif name == "bar": data[name] = [10.0, 20.0, 30.0, 40.0, 50.0] - elif name == 'baz': - data[name] = ['cat', 'dog', 'bird', 'fish', 'lizard'] + elif name == "baz": + data[name] = ["cat", "dog", "bird", "fish", "lizard"] expected_df = 
pd.DataFrame(data) obtained_df = table.to_dataframe() self.assertTrue(expected_df.equals(obtained_df)) def test_from_dataframe(self): - df = pd.DataFrame({ - 'foo': [1, 2, 3, 4, 5], - 'bar': [10.0, 20.0, 30.0, 40.0, 50.0], - 'baz': ['cat', 'dog', 'bird', 'fish', 'lizard'] - }).loc[:, ('foo', 'bar', 'baz')] - - obtained_table = DynamicTable.from_dataframe(df, 'test') + df = pd.DataFrame( + { + "foo": [1, 2, 3, 4, 5], + "bar": [10.0, 20.0, 30.0, 40.0, 50.0], + "baz": ["cat", "dog", "bird", "fish", "lizard"], + } + ).loc[:, ("foo", "bar", "baz")] + + obtained_table = DynamicTable.from_dataframe(df, "test") self.check_table(obtained_table) def test_from_dataframe_eq(self): - expected = DynamicTable(name='test_table', description='the expected table') - expected.add_column('a', '2d column') - expected.add_column('b', '1d column') - expected.add_row(a=[1, 2, 3], b='4') - expected.add_row(a=[1, 2, 3], b='5') - expected.add_row(a=[1, 2, 3], b='6') - - df = pd.DataFrame({ - 'a': [[1, 2, 3], - [1, 2, 3], - [1, 2, 3]], - 'b': ['4', '5', '6'] - }) - coldesc = {'a': '2d column', 'b': '1d column'} - received = DynamicTable.from_dataframe(df, - 'test_table', - table_description='the expected table', - column_descriptions=coldesc) + expected = DynamicTable(name="test_table", description="the expected table") + expected.add_column("a", "2d column") + expected.add_column("b", "1d column") + expected.add_row(a=[1, 2, 3], b="4") + expected.add_row(a=[1, 2, 3], b="5") + expected.add_row(a=[1, 2, 3], b="6") + + df = pd.DataFrame({"a": [[1, 2, 3], [1, 2, 3], [1, 2, 3]], "b": ["4", "5", "6"]}) + coldesc = {"a": "2d column", "b": "1d column"} + received = DynamicTable.from_dataframe( + df, + "test_table", + table_description="the expected table", + column_descriptions=coldesc, + ) self.assertContainerEqual(expected, received, ignore_hdmf_attrs=True) def test_from_dataframe_dup_attr(self): @@ -569,80 +582,115 @@ def test_from_dataframe_dup_attr(self): DynamicTable attribute (e.g., description), that the table can be created, the existing attribute is not altered, a warning is raised, and the column can still be accessed using the table[col_name] syntax. """ - df = pd.DataFrame({ - 'parent': [1, 2, 3, 4, 5], - 'name': [10.0, 20.0, 30.0, 40.0, 50.0], - 'description': ['cat', 'dog', 'bird', 'fish', 'lizard'] - }) + df = pd.DataFrame( + { + "parent": [1, 2, 3, 4, 5], + "name": [10.0, 20.0, 30.0, 40.0, 50.0], + "description": ["cat", "dog", "bird", "fish", "lizard"], + } + ) # technically there are three separate warnings but just catch one here - msg1 = ("An attribute 'parent' already exists on DynamicTable 'test' so this column cannot be accessed " - "as an attribute, e.g., table.parent; it can only be accessed using other methods, e.g., " - "table['parent'].") + msg1 = ( + "An attribute 'parent' already exists on DynamicTable 'test' so this column" + " cannot be accessed as an attribute, e.g., table.parent; it can only be" + " accessed using other methods, e.g., table['parent']." 
+ ) with self.assertWarnsWith(UserWarning, msg1): - table = DynamicTable.from_dataframe(df, 'test') - self.assertEqual(table.name, 'test') - self.assertEqual(table.description, '') + table = DynamicTable.from_dataframe(df, "test") + self.assertEqual(table.name, "test") + self.assertEqual(table.description, "") self.assertIsNone(table.parent) - self.assertEqual(table['name'].name, 'name') - self.assertEqual(table['description'].name, 'description') - self.assertEqual(table['parent'].name, 'parent') + self.assertEqual(table["name"].name, "name") + self.assertEqual(table["description"].name, "description") + self.assertEqual(table["parent"].name, "parent") def test_missing_columns(self): table = self.with_spec() with self.assertRaises(ValueError): - table.add_row({'bar': 60.0, 'foo': [6]}, None) + table.add_row({"bar": 60.0, "foo": [6]}, None) def test_enforce_unique_id_error(self): table = self.with_spec() - table.add_row(id=10, data={'foo': 1, 'bar': 10.0, 'baz': 'cat'}, enforce_unique_id=True) + table.add_row( + id=10, + data={"foo": 1, "bar": 10.0, "baz": "cat"}, + enforce_unique_id=True, + ) with self.assertRaises(ValueError): - table.add_row(id=10, data={'foo': 1, 'bar': 10.0, 'baz': 'cat'}, enforce_unique_id=True) + table.add_row( + id=10, + data={"foo": 1, "bar": 10.0, "baz": "cat"}, + enforce_unique_id=True, + ) def test_not_enforce_unique_id_error(self): table = self.with_spec() - table.add_row(id=10, data={'foo': 1, 'bar': 10.0, 'baz': 'cat'}, enforce_unique_id=False) + table.add_row( + id=10, + data={"foo": 1, "bar": 10.0, "baz": "cat"}, + enforce_unique_id=False, + ) try: - table.add_row(id=10, data={'foo': 1, 'bar': 10.0, 'baz': 'cat'}, enforce_unique_id=False) + table.add_row( + id=10, + data={"foo": 1, "bar": 10.0, "baz": "cat"}, + enforce_unique_id=False, + ) except ValueError as e: self.fail("add row with non unique id raised error %s" % str(e)) def test_bad_id_type_error(self): table = self.with_spec() with self.assertRaises(TypeError): - table.add_row(id=10.1, data={'foo': 1, 'bar': 10.0, 'baz': 'cat'}, enforce_unique_id=True) + table.add_row( + id=10.1, + data={"foo": 1, "bar": 10.0, "baz": "cat"}, + enforce_unique_id=True, + ) with self.assertRaises(TypeError): - table.add_row(id='str', data={'foo': 1, 'bar': 10.0, 'baz': 'cat'}, enforce_unique_id=True) + table.add_row( + id="str", + data={"foo": 1, "bar": 10.0, "baz": "cat"}, + enforce_unique_id=True, + ) def test_extra_columns(self): table = self.with_spec() with self.assertRaises(ValueError): - table.add_row({'bar': 60.0, 'foo': 6, 'baz': 'oryx', 'qax': -1}, None) + table.add_row({"bar": 60.0, "foo": 6, "baz": "oryx", "qax": -1}, None) def test_nd_array_to_df(self): data = np.array([[1, 1, 1], [2, 2, 2], [3, 3, 3]]) - col = VectorData(name='data', description='desc', data=data) - df = DynamicTable(name='test', description='desc', id=np.arange(3, dtype='int'), - columns=(col, )).to_dataframe() - df2 = pd.DataFrame({'data': [x for x in data]}, - index=pd.Index(name='id', data=[0, 1, 2])) + col = VectorData(name="data", description="desc", data=data) + df = DynamicTable( + name="test", + description="desc", + id=np.arange(3, dtype="int"), + columns=(col,), + ).to_dataframe() + df2 = pd.DataFrame( + {"data": [x for x in data]}, + index=pd.Index(name="id", data=[0, 1, 2]), + ) pd.testing.assert_frame_equal(df, df2) def test_id_search(self): table = self.with_spec() - data = [{'foo': 1, 'bar': 10.0, 'baz': 'cat'}, - {'foo': 2, 'bar': 20.0, 'baz': 'dog'}, - {'foo': 3, 'bar': 30.0, 'baz': 'bird'}, # id=2 - {'foo': 4, 'bar': 
40.0, 'baz': 'fish'}, - {'foo': 5, 'bar': 50.0, 'baz': 'lizard'} # id=4 - ] + data = [ + {"foo": 1, "bar": 10.0, "baz": "cat"}, + {"foo": 2, "bar": 20.0, "baz": "dog"}, + {"foo": 3, "bar": 30.0, "baz": "bird"}, # id=2 + {"foo": 4, "bar": 40.0, "baz": "fish"}, + {"foo": 5, "bar": 50.0, "baz": "lizard"}, # id=4 + ] for i in data: table.add_row(i) res = table[table.id == [2, 4]] self.assertEqual(len(res), 2) - self.assertTupleEqual(tuple(res.iloc[0]), (3, 30.0, 'bird')) - self.assertTupleEqual(tuple(res.iloc[1]), (5, 50.0, 'lizard')) + self.assertTupleEqual(tuple(res.iloc[0]), (3, 30.0, "bird")) + self.assertTupleEqual(tuple(res.iloc[1]), (5, 50.0, "lizard")) def test_repr(self): table = self.with_spec() @@ -661,28 +709,37 @@ def test_repr(self): def test_add_column_existing_attr(self): table = self.with_table_columns() - attrs = ['name', 'description', 'parent', 'id', 'fields'] # just a few + attrs = ["name", "description", "parent", "id", "fields"] # just a few for attr in attrs: with self.subTest(attr=attr): - msg = ("An attribute '%s' already exists on DynamicTable 'with_table_columns' so this column cannot be " - "accessed as an attribute, e.g., table.%s; it can only be accessed using other methods, " - "e.g., table['%s']." % (attr, attr, attr)) + msg = ( + "An attribute '%s' already exists on DynamicTable" + " 'with_table_columns' so this column cannot be accessed as an" + " attribute, e.g., table.%s; it can only be accessed using other" + " methods, e.g., table['%s']." % (attr, attr, attr) + ) with self.assertWarnsWith(UserWarning, msg): - table.add_column(name=attr, description='') + table.add_column(name=attr, description="") def test_init_columns_existing_attr(self): - attrs = ['name', 'description', 'parent', 'id', 'fields'] # just a few + attrs = ["name", "description", "parent", "id", "fields"] # just a few for attr in attrs: with self.subTest(attr=attr): - cols = [VectorData(name=attr, description='')] - msg = ("An attribute '%s' already exists on DynamicTable 'test_table' so this column cannot be " - "accessed as an attribute, e.g., table.%s; it can only be accessed using other methods, " - "e.g., table['%s']." % (attr, attr, attr)) + cols = [VectorData(name=attr, description="")] + msg = ( + "An attribute '%s' already exists on DynamicTable 'test_table' so" + " this column cannot be accessed as an attribute, e.g., table.%s;" + " it can only be accessed using other methods, e.g., table['%s']." % (attr, attr, attr) + ) with self.assertWarnsWith(UserWarning, msg): - DynamicTable(name="test_table", description='a test table', columns=cols) + DynamicTable( + name="test_table", + description="a test table", + columns=cols, + ) def test_colnames_none(self): - table = DynamicTable(name='table0', description='an example table') + table = DynamicTable(name="table0", description="an example table") self.assertTupleEqual(table.colnames, tuple()) self.assertTupleEqual(table.columns, tuple()) @@ -694,7 +751,7 @@ def test_index_out_of_bounds(self): def test_no_df_nested(self): table = self.with_columns_and_data() - msg = 'DynamicTable.get() with df=False and index=False is not yet supported.' + msg = "DynamicTable.get() with df=False and index=False is not yet supported." 
with self.assertRaisesWith(ValueError, msg): table.get(0, df=False, index=False) @@ -702,118 +759,154 @@ def test_multidim_col(self): multidim_data = [ [[1, 2], [3, 4], [5, 6]], ((1, 2), (3, 4), (5, 6)), - [(1, 'a', True), (2, 'b', False), (3, 'c', True)], + [(1, "a", True), (2, "b", False), (3, "c", True)], ] columns = [ - VectorData(name=s['name'], description=s['description'], data=d) - for s, d in zip(self.spec, multidim_data) + VectorData(name=s["name"], description=s["description"], data=d) for s, d in zip(self.spec, multidim_data) ] - table = DynamicTable(name="with_columns_and_data", description='a test table', columns=columns) + table = DynamicTable( + name="with_columns_and_data", + description="a test table", + columns=columns, + ) df = table.to_dataframe() - df2 = pd.DataFrame({'foo': multidim_data[0], - 'bar': multidim_data[1], - 'baz': multidim_data[2]}, - index=pd.Index(name='id', data=[0, 1, 2])) + df2 = pd.DataFrame( + { + "foo": multidim_data[0], + "bar": multidim_data[1], + "baz": multidim_data[2], + }, + index=pd.Index(name="id", data=[0, 1, 2]), + ) pd.testing.assert_frame_equal(df, df2) - df3 = pd.DataFrame({'foo': [multidim_data[0][0]], - 'bar': [multidim_data[1][0]], - 'baz': [multidim_data[2][0]]}, - index=pd.Index(name='id', data=[0])) + df3 = pd.DataFrame( + { + "foo": [multidim_data[0][0]], + "bar": [multidim_data[1][0]], + "baz": [multidim_data[2][0]], + }, + index=pd.Index(name="id", data=[0]), + ) pd.testing.assert_frame_equal(table.get(0), df3) def test_multidim_col_one_elt_list(self): data = [[1, 2]] - col = VectorData(name='data', description='desc', data=data) - table = DynamicTable(name='test', description='desc', columns=(col, )) + col = VectorData(name="data", description="desc", data=data) + table = DynamicTable(name="test", description="desc", columns=(col,)) df = table.to_dataframe() - df2 = pd.DataFrame({'data': [x for x in data]}, - index=pd.Index(name='id', data=[0])) + df2 = pd.DataFrame({"data": [x for x in data]}, index=pd.Index(name="id", data=[0])) pd.testing.assert_frame_equal(df, df2) pd.testing.assert_frame_equal(table.get(0), df2) def test_multidim_col_one_elt_tuple(self): data = [(1, 2)] - col = VectorData(name='data', description='desc', data=data) - table = DynamicTable(name='test', description='desc', columns=(col, )) + col = VectorData(name="data", description="desc", data=data) + table = DynamicTable(name="test", description="desc", columns=(col,)) df = table.to_dataframe() - df2 = pd.DataFrame({'data': [x for x in data]}, - index=pd.Index(name='id', data=[0])) + df2 = pd.DataFrame({"data": [x for x in data]}, index=pd.Index(name="id", data=[0])) pd.testing.assert_frame_equal(df, df2) pd.testing.assert_frame_equal(table.get(0), df2) def test_eq(self): columns = [ - VectorData(name=s['name'], description=s['description'], data=d) - for s, d in zip(self.spec, self.data) + VectorData(name=s["name"], description=s["description"], data=d) for s, d in zip(self.spec, self.data) ] - test_table = DynamicTable(name="with_columns_and_data", description='a test table', columns=columns) + test_table = DynamicTable( + name="with_columns_and_data", + description="a test table", + columns=columns, + ) table = self.with_columns_and_data() self.assertTrue(table == test_table) def test_eq_from_df(self): - df = pd.DataFrame({ - 'foo': [1, 2, 3, 4, 5], - 'bar': [10.0, 20.0, 30.0, 40.0, 50.0], - 'baz': ['cat', 'dog', 'bird', 'fish', 'lizard'] - }).loc[:, ('foo', 'bar', 'baz')] - - test_table = DynamicTable.from_dataframe(df, 'with_columns_and_data', 
table_description='a test table') + df = pd.DataFrame( + { + "foo": [1, 2, 3, 4, 5], + "bar": [10.0, 20.0, 30.0, 40.0, 50.0], + "baz": ["cat", "dog", "bird", "fish", "lizard"], + } + ).loc[:, ("foo", "bar", "baz")] + + test_table = DynamicTable.from_dataframe(df, "with_columns_and_data", table_description="a test table") table = self.with_columns_and_data() self.assertTrue(table == test_table) def test_eq_diff_missing_col(self): columns = [ - VectorData(name=s['name'], description=s['description'], data=d) - for s, d in zip(self.spec, self.data) + VectorData(name=s["name"], description=s["description"], data=d) for s, d in zip(self.spec, self.data) ] del columns[-1] - test_table = DynamicTable(name="with_columns_and_data", description='a test table', columns=columns) + test_table = DynamicTable( + name="with_columns_and_data", + description="a test table", + columns=columns, + ) table = self.with_columns_and_data() self.assertFalse(table == test_table) def test_eq_diff_name(self): columns = [ - VectorData(name=s['name'], description=s['description'], data=d) - for s, d in zip(self.spec, self.data) + VectorData(name=s["name"], description=s["description"], data=d) for s, d in zip(self.spec, self.data) ] - test_table = DynamicTable(name="wrong name", description='a test table', columns=columns) + test_table = DynamicTable(name="wrong name", description="a test table", columns=columns) table = self.with_columns_and_data() self.assertFalse(table == test_table) def test_eq_diff_desc(self): columns = [ - VectorData(name=s['name'], description=s['description'], data=d) - for s, d in zip(self.spec, self.data) + VectorData(name=s["name"], description=s["description"], data=d) for s, d in zip(self.spec, self.data) ] - test_table = DynamicTable(name="with_columns_and_data", description='wrong description', columns=columns) + test_table = DynamicTable( + name="with_columns_and_data", + description="wrong description", + columns=columns, + ) table = self.with_columns_and_data() self.assertFalse(table == test_table) def test_eq_bad_type(self): - container = Container('test_container') + container = Container("test_container") table = self.with_columns_and_data() self.assertFalse(table == container) class TestDynamicTableRoundTrip(H5RoundTripMixin, TestCase): - def setUpContainer(self): - table = DynamicTable(name='table0', description='an example table') - table.add_column(name='foo', description='an int column') - table.add_column(name='bar', description='a float column') - table.add_column(name='baz', description='a string column') - table.add_column(name='qux', description='a boolean column') - table.add_column(name='corge', description='a doubly indexed int column', index=2) - table.add_column(name='quux', description='an enum column', enum=True) - table.add_row(foo=27, bar=28.0, baz="cat", corge=[[1, 2, 3], [4, 5, 6]], qux=True, quux='a') - table.add_row(foo=37, bar=38.0, baz="dog", corge=[[11, 12, 13], [14, 15, 16]], qux=False, quux='b') - table.add_column(name='agv', description='a column with autogenerated multi vector index', - data=[[[1, 2, 3], [4, 5]], [[6, 7], [8, 9, 10]]], index=2) + table = DynamicTable(name="table0", description="an example table") + table.add_column(name="foo", description="an int column") + table.add_column(name="bar", description="a float column") + table.add_column(name="baz", description="a string column") + table.add_column(name="qux", description="a boolean column") + table.add_column(name="corge", description="a doubly indexed int column", index=2) + 
table.add_column(name="quux", description="an enum column", enum=True) + table.add_row( + foo=27, + bar=28.0, + baz="cat", + corge=[[1, 2, 3], [4, 5, 6]], + qux=True, + quux="a", + ) + table.add_row( + foo=37, + bar=38.0, + baz="dog", + corge=[[11, 12, 13], [14, 15, 16]], + qux=False, + quux="b", + ) + table.add_column( + name="agv", + description="a column with autogenerated multi vector index", + data=[[[1, 2, 3], [4, 5]], [[6, 7], [8, 9, 10]]], + index=2, + ) return table def test_index_out_of_bounds(self): @@ -827,136 +920,158 @@ class TestEmptyDynamicTableRoundTrip(H5RoundTripMixin, TestCase): """Test roundtripping a DynamicTable with no rows and no columns.""" def setUpContainer(self): - table = DynamicTable(name='table0', description='an example table') + table = DynamicTable(name="table0", description="an example table") return table class TestDynamicTableRegion(TestCase): - def setUp(self): self.spec = [ - {'name': 'foo', 'description': 'foo column'}, - {'name': 'bar', 'description': 'bar column'}, - {'name': 'baz', 'description': 'baz column'}, + {"name": "foo", "description": "foo column"}, + {"name": "bar", "description": "bar column"}, + {"name": "baz", "description": "baz column"}, ] self.data = [ [1, 2, 3, 4, 5], [10.0, 20.0, 30.0, 40.0, 50.0], - ['cat', 'dog', 'bird', 'fish', 'lizard'] + ["cat", "dog", "bird", "fish", "lizard"], ] def with_columns_and_data(self): columns = [ - VectorData(name=s['name'], description=s['description'], data=d) - for s, d in zip(self.spec, self.data) + VectorData(name=s["name"], description=s["description"], data=d) for s, d in zip(self.spec, self.data) ] - return DynamicTable(name="with_columns_and_data", description='a test table', columns=columns) + return DynamicTable( + name="with_columns_and_data", + description="a test table", + columns=columns, + ) def test_indexed_dynamic_table_region(self): table = self.with_columns_and_data() - dynamic_table_region = DynamicTableRegion(name='dtr', data=[1, 2, 2], description='desc', table=table) + dynamic_table_region = DynamicTableRegion(name="dtr", data=[1, 2, 2], description="desc", table=table) fetch_ids = dynamic_table_region[:3].index.values self.assertListEqual(fetch_ids.tolist(), [1, 2, 2]) def test_dynamic_table_region_iteration(self): table = self.with_columns_and_data() - dynamic_table_region = DynamicTableRegion(name='dtr', data=[0, 1, 2, 3, 4], description='desc', table=table) + dynamic_table_region = DynamicTableRegion(name="dtr", data=[0, 1, 2, 3, 4], description="desc", table=table) for ii, item in enumerate(dynamic_table_region): self.assertTrue(table[ii].equals(item)) def test_dynamic_table_region_shape(self): table = self.with_columns_and_data() - dynamic_table_region = DynamicTableRegion(name='dtr', data=[0, 1, 2, 3, 4], description='desc', table=table) + dynamic_table_region = DynamicTableRegion(name="dtr", data=[0, 1, 2, 3, 4], description="desc", table=table) self.assertTupleEqual(dynamic_table_region.shape, (5, 3)) def test_dynamic_table_region_to_dataframe(self): table = self.with_columns_and_data() - dynamic_table_region = DynamicTableRegion(name='dtr', data=[0, 1, 2, 2], description='desc', table=table) + dynamic_table_region = DynamicTableRegion(name="dtr", data=[0, 1, 2, 2], description="desc", table=table) res = dynamic_table_region.to_dataframe() self.assertListEqual(res.index.tolist(), [0, 1, 2, 2]) - self.assertListEqual(res['foo'].tolist(), [1, 2, 3, 3]) - self.assertListEqual(res['bar'].tolist(), [10.0, 20.0, 30.0, 30.0]) - self.assertListEqual(res['baz'].tolist(), 
['cat', 'dog', 'bird', 'bird']) + self.assertListEqual(res["foo"].tolist(), [1, 2, 3, 3]) + self.assertListEqual(res["bar"].tolist(), [10.0, 20.0, 30.0, 30.0]) + self.assertListEqual(res["baz"].tolist(), ["cat", "dog", "bird", "bird"]) def test_dynamic_table_region_to_dataframe_exclude_cols(self): table = self.with_columns_and_data() - dynamic_table_region = DynamicTableRegion(name='dtr', data=[0, 1, 2, 2], description='desc', table=table) - res = dynamic_table_region.to_dataframe(exclude={'baz', 'foo'}) + dynamic_table_region = DynamicTableRegion(name="dtr", data=[0, 1, 2, 2], description="desc", table=table) + res = dynamic_table_region.to_dataframe(exclude={"baz", "foo"}) self.assertListEqual(res.index.tolist(), [0, 1, 2, 2]) self.assertEqual(len(res.columns), 1) - self.assertListEqual(res['bar'].tolist(), [10.0, 20.0, 30.0, 30.0]) + self.assertListEqual(res["bar"].tolist(), [10.0, 20.0, 30.0, 30.0]) def test_dynamic_table_region_getitem_slice(self): table = self.with_columns_and_data() - dynamic_table_region = DynamicTableRegion(name='dtr', data=[0, 1, 2, 2], description='desc', table=table) + dynamic_table_region = DynamicTableRegion(name="dtr", data=[0, 1, 2, 2], description="desc", table=table) res = dynamic_table_region[1:3] self.assertListEqual(res.index.tolist(), [1, 2]) - self.assertListEqual(res['foo'].tolist(), [2, 3]) - self.assertListEqual(res['bar'].tolist(), [20.0, 30.0]) - self.assertListEqual(res['baz'].tolist(), ['dog', 'bird']) + self.assertListEqual(res["foo"].tolist(), [2, 3]) + self.assertListEqual(res["bar"].tolist(), [20.0, 30.0]) + self.assertListEqual(res["baz"].tolist(), ["dog", "bird"]) def test_dynamic_table_region_getitem_single_row_by_index(self): table = self.with_columns_and_data() - dynamic_table_region = DynamicTableRegion(name='dtr', data=[0, 1, 2, 2], description='desc', table=table) + dynamic_table_region = DynamicTableRegion(name="dtr", data=[0, 1, 2, 2], description="desc", table=table) res = dynamic_table_region[2] - self.assertListEqual(res.index.tolist(), [2, ]) - self.assertListEqual(res['foo'].tolist(), [3, ]) - self.assertListEqual(res['bar'].tolist(), [30.0, ]) - self.assertListEqual(res['baz'].tolist(), ['bird', ]) + self.assertListEqual( + res.index.tolist(), + [ + 2, + ], + ) + self.assertListEqual( + res["foo"].tolist(), + [ + 3, + ], + ) + self.assertListEqual( + res["bar"].tolist(), + [ + 30.0, + ], + ) + self.assertListEqual( + res["baz"].tolist(), + [ + "bird", + ], + ) def test_dynamic_table_region_getitem_single_cell(self): table = self.with_columns_and_data() - dynamic_table_region = DynamicTableRegion(name='dtr', data=[0, 1, 2, 2], description='desc', table=table) - res = dynamic_table_region[2, 'foo'] + dynamic_table_region = DynamicTableRegion(name="dtr", data=[0, 1, 2, 2], description="desc", table=table) + res = dynamic_table_region[2, "foo"] self.assertEqual(res, 3) - res = dynamic_table_region[1, 'baz'] - self.assertEqual(res, 'dog') + res = dynamic_table_region[1, "baz"] + self.assertEqual(res, "dog") def test_dynamic_table_region_getitem_slice_of_column(self): table = self.with_columns_and_data() - dynamic_table_region = DynamicTableRegion(name='dtr', data=[0, 1, 2, 2], description='desc', table=table) - res = dynamic_table_region[0:3, 'foo'] + dynamic_table_region = DynamicTableRegion(name="dtr", data=[0, 1, 2, 2], description="desc", table=table) + res = dynamic_table_region[0:3, "foo"] self.assertListEqual(res, [1, 2, 3]) - res = dynamic_table_region[1:3, 'baz'] - self.assertListEqual(res, ['dog', 'bird']) + res = 
dynamic_table_region[1:3, "baz"] + self.assertListEqual(res, ["dog", "bird"]) def test_dynamic_table_region_getitem_bad_index(self): table = self.with_columns_and_data() - dynamic_table_region = DynamicTableRegion(name='dtr', data=[0, 1, 2, 2], description='desc', table=table) + dynamic_table_region = DynamicTableRegion(name="dtr", data=[0, 1, 2, 2], description="desc", table=table) with self.assertRaises(ValueError): _ = dynamic_table_region[True] def test_dynamic_table_region_table_prop(self): table = self.with_columns_and_data() - dynamic_table_region = DynamicTableRegion(name='dtr', data=[0, 1, 2, 2], description='desc', table=table) + dynamic_table_region = DynamicTableRegion(name="dtr", data=[0, 1, 2, 2], description="desc", table=table) self.assertEqual(table, dynamic_table_region.table) def test_dynamic_table_region_set_table_prop(self): table = self.with_columns_and_data() - dynamic_table_region = DynamicTableRegion(name='dtr', data=[0, 1, 2, 2], description='desc') + dynamic_table_region = DynamicTableRegion(name="dtr", data=[0, 1, 2, 2], description="desc") dynamic_table_region.table = table self.assertEqual(table, dynamic_table_region.table) def test_dynamic_table_region_set_table_prop_to_none(self): table = self.with_columns_and_data() - dynamic_table_region = DynamicTableRegion(name='dtr', data=[0, 1, 2, 2], description='desc', table=table) + dynamic_table_region = DynamicTableRegion(name="dtr", data=[0, 1, 2, 2], description="desc", table=table) try: dynamic_table_region.table = None except AttributeError: self.fail("DynamicTableRegion table setter raised AttributeError unexpectedly!") - @unittest.skip('we no longer check data contents for performance reasons') + @unittest.skip("we no longer check data contents for performance reasons") def test_dynamic_table_region_set_with_bad_data(self): table = self.with_columns_and_data() # index 5 is out of range - dynamic_table_region = DynamicTableRegion(name='dtr', data=[5, 1], description='desc') + dynamic_table_region = DynamicTableRegion(name="dtr", data=[5, 1], description="desc") with self.assertRaises(IndexError): dynamic_table_region.table = table self.assertIsNone(dynamic_table_region.table) def test_repr(self): table = self.with_columns_and_data() - dynamic_table_region = DynamicTableRegion(name='dtr', data=[1, 2, 2], description='desc', table=table) + dynamic_table_region = DynamicTableRegion(name="dtr", data=[1, 2, 2], description="desc", table=table) expected = """dtr hdmf.common.table.DynamicTableRegion at 0x%d Target table: with_columns_and_data hdmf.common.table.DynamicTable at 0x%d """ @@ -965,52 +1080,59 @@ def test_repr(self): def test_no_df_nested(self): table = self.with_columns_and_data() - dynamic_table_region = DynamicTableRegion(name='dtr', data=[0, 1, 2, 2], description='desc', table=table) - msg = 'DynamicTableRegion.get() with df=False and index=False is not yet supported.' + dynamic_table_region = DynamicTableRegion(name="dtr", data=[0, 1, 2, 2], description="desc", table=table) + msg = "DynamicTableRegion.get() with df=False and index=False is not yet supported." 
with self.assertRaisesWith(ValueError, msg): dynamic_table_region.get(0, df=False, index=False) class DynamicTableRegionRoundTrip(H5RoundTripMixin, TestCase): - def make_tables(self): self.spec2 = [ - {'name': 'qux', 'description': 'qux column'}, - {'name': 'quz', 'description': 'quz column'}, + {"name": "qux", "description": "qux column"}, + {"name": "quz", "description": "quz column"}, ] self.data2 = [ - ['qux_1', 'qux_2'], - ['quz_1', 'quz_2'], + ["qux_1", "qux_2"], + ["quz_1", "quz_2"], ] target_columns = [ - VectorData(name=s['name'], description=s['description'], data=d) - for s, d in zip(self.spec2, self.data2) + VectorData(name=s["name"], description=s["description"], data=d) for s, d in zip(self.spec2, self.data2) ] - target_table = DynamicTable(name="target_table", - description='example table to target with a DynamicTableRegion', - columns=target_columns) + target_table = DynamicTable( + name="target_table", + description="example table to target with a DynamicTableRegion", + columns=target_columns, + ) self.spec1 = [ - {'name': 'foo', 'description': 'foo column'}, - {'name': 'bar', 'description': 'bar column'}, - {'name': 'baz', 'description': 'baz column'}, - {'name': 'dtr', 'description': 'DTR'}, + {"name": "foo", "description": "foo column"}, + {"name": "bar", "description": "bar column"}, + {"name": "baz", "description": "baz column"}, + {"name": "dtr", "description": "DTR"}, ] self.data1 = [ [1, 2, 3, 4, 5], [10.0, 20.0, 30.0, 40.0, 50.0], - ['cat', 'dog', 'bird', 'fish', 'lizard'] + ["cat", "dog", "bird", "fish", "lizard"], ] columns = [ - VectorData(name=s['name'], description=s['description'], data=d) - for s, d in zip(self.spec1, self.data1) + VectorData(name=s["name"], description=s["description"], data=d) for s, d in zip(self.spec1, self.data1) ] - columns.append(DynamicTableRegion(name='dtr', description='example DynamicTableRegion', - data=[0, 1, 1, 0, 1], table=target_table)) - table = DynamicTable(name="table_with_dtr", - description='a test table that has a DynamicTableRegion', - columns=columns) + columns.append( + DynamicTableRegion( + name="dtr", + description="example DynamicTableRegion", + data=[0, 1, 1, 0, 1], + table=target_table, + ) + ) + table = DynamicTable( + name="table_with_dtr", + description="a test table that has a DynamicTableRegion", + columns=columns, + ) return table, target_table def setUp(self): @@ -1018,79 +1140,85 @@ def setUp(self): super().setUp() def setUpContainer(self): - multi_container = SimpleMultiContainer(name='multi', containers=[self.table, self.target_table]) + multi_container = SimpleMultiContainer(name="multi", containers=[self.table, self.target_table]) return multi_container def _get(self, arg): mc = self.roundtripContainer() - table = mc.containers['table_with_dtr'] + table = mc.containers["table_with_dtr"] return table.get(arg) def _get_nested(self, arg): mc = self.roundtripContainer() - table = mc.containers['table_with_dtr'] + table = mc.containers["table_with_dtr"] return table.get(arg, index=False) def _get_nodf(self, arg): mc = self.roundtripContainer() - table = mc.containers['table_with_dtr'] + table = mc.containers["table_with_dtr"] return table.get(arg, df=False) def _getitem(self, arg): mc = self.roundtripContainer() - table = mc.containers['table_with_dtr'] + table = mc.containers["table_with_dtr"] return table[arg] def test_getitem_oor(self): - msg = 'Row index 12 out of range for DynamicTable \'table_with_dtr\' (length 5).' + msg = "Row index 12 out of range for DynamicTable 'table_with_dtr' (length 5)." 
with self.assertRaisesWith(IndexError, msg): self._getitem(12) def test_getitem_badcol(self): with self.assertRaises(KeyError): - self._getitem('boo') + self._getitem("boo") def _assert_two_elem_df(self, rec): - columns = ['foo', 'bar', 'baz', 'dtr'] - data = [[1, 10.0, 'cat', 0], - [2, 20.0, 'dog', 1]] - exp = pd.DataFrame(data=data, columns=columns, index=pd.Series(name='id', data=[0, 1])) + columns = ["foo", "bar", "baz", "dtr"] + data = [[1, 10.0, "cat", 0], [2, 20.0, "dog", 1]] + exp = pd.DataFrame(data=data, columns=columns, index=pd.Series(name="id", data=[0, 1])) pd.testing.assert_frame_equal(rec, exp, check_dtype=False) def _assert_one_elem_df(self, rec): - columns = ['foo', 'bar', 'baz', 'dtr'] - data = [[1, 10.0, 'cat', 0]] - exp = pd.DataFrame(data=data, columns=columns, index=pd.Series(name='id', data=[0])) + columns = ["foo", "bar", "baz", "dtr"] + data = [[1, 10.0, "cat", 0]] + exp = pd.DataFrame(data=data, columns=columns, index=pd.Series(name="id", data=[0])) pd.testing.assert_frame_equal(rec, exp, check_dtype=False) def _assert_two_elem_df_nested(self, rec): - nested_columns = ['qux', 'quz'] - nested_data = [['qux_1', 'quz_1'], ['qux_2', 'quz_2']] - nested_df = pd.DataFrame(data=nested_data, columns=nested_columns, index=pd.Series(name='id', data=[0, 1])) + nested_columns = ["qux", "quz"] + nested_data = [["qux_1", "quz_1"], ["qux_2", "quz_2"]] + nested_df = pd.DataFrame( + data=nested_data, + columns=nested_columns, + index=pd.Series(name="id", data=[0, 1]), + ) - columns = ['foo', 'bar', 'baz'] - data = [[1, 10.0, 'cat'], - [2, 20.0, 'dog']] - exp = pd.DataFrame(data=data, columns=columns, index=pd.Series(name='id', data=[0, 1])) + columns = ["foo", "bar", "baz"] + data = [[1, 10.0, "cat"], [2, 20.0, "dog"]] + exp = pd.DataFrame(data=data, columns=columns, index=pd.Series(name="id", data=[0, 1])) # remove nested dataframe and test each df separately - pd.testing.assert_frame_equal(rec['dtr'][0], nested_df.iloc[[0]]) - pd.testing.assert_frame_equal(rec['dtr'][1], nested_df.iloc[[1]]) - del rec['dtr'] + pd.testing.assert_frame_equal(rec["dtr"][0], nested_df.iloc[[0]]) + pd.testing.assert_frame_equal(rec["dtr"][1], nested_df.iloc[[1]]) + del rec["dtr"] pd.testing.assert_frame_equal(rec, exp, check_dtype=False) def _assert_one_elem_df_nested(self, rec): - nested_columns = ['qux', 'quz'] - nested_data = [['qux_1', 'quz_1'], ['qux_2', 'quz_2']] - nested_df = pd.DataFrame(data=nested_data, columns=nested_columns, index=pd.Series(name='id', data=[0, 1])) + nested_columns = ["qux", "quz"] + nested_data = [["qux_1", "quz_1"], ["qux_2", "quz_2"]] + nested_df = pd.DataFrame( + data=nested_data, + columns=nested_columns, + index=pd.Series(name="id", data=[0, 1]), + ) - columns = ['foo', 'bar', 'baz'] - data = [[1, 10.0, 'cat']] - exp = pd.DataFrame(data=data, columns=columns, index=pd.Series(name='id', data=[0])) + columns = ["foo", "bar", "baz"] + data = [[1, 10.0, "cat"]] + exp = pd.DataFrame(data=data, columns=columns, index=pd.Series(name="id", data=[0])) # remove nested dataframe and test each df separately - pd.testing.assert_frame_equal(rec['dtr'][0], nested_df.iloc[[0]]) - del rec['dtr'] + pd.testing.assert_frame_equal(rec["dtr"][0], nested_df.iloc[[0]]) + del rec["dtr"] pd.testing.assert_frame_equal(rec, exp, check_dtype=False) ##################### @@ -1151,7 +1279,7 @@ def test_get_nested_slice(self): # tests DynamicTableRegion.get, DO NOT return a DataFrame def test_get_nodf_int(self): rec = self._get_nodf(0) - exp = [0, 1, 10.0, 'cat', 0] + exp = [0, 1, 10.0, "cat", 0] 
self.assertListEqual(rec, exp) def _assert_list_of_ndarray_equal(self, l1, l2): @@ -1167,96 +1295,151 @@ def _assert_list_of_ndarray_equal(self, l1, l2): def test_get_nodf_list_single(self): rec = self._get_nodf([0]) - exp = [np.array([0]), np.array([1]), np.array([10.0]), np.array(['cat']), np.array([0])] + exp = [ + np.array([0]), + np.array([1]), + np.array([10.0]), + np.array(["cat"]), + np.array([0]), + ] self._assert_list_of_ndarray_equal(exp, rec) def test_get_nodf_list(self): rec = self._get_nodf([0, 1]) - exp = [np.array([0, 1]), np.array([1, 2]), np.array([10.0, 20.0]), np.array(['cat', 'dog']), np.array([0, 1])] + exp = [ + np.array([0, 1]), + np.array([1, 2]), + np.array([10.0, 20.0]), + np.array(["cat", "dog"]), + np.array([0, 1]), + ] self._assert_list_of_ndarray_equal(exp, rec) def test_get_nodf_slice(self): rec = self._get_nodf(slice(0, 2, None)) - exp = [np.array([0, 1]), np.array([1, 2]), np.array([10.0, 20.0]), np.array(['cat', 'dog']), np.array([0, 1])] + exp = [ + np.array([0, 1]), + np.array([1, 2]), + np.array([10.0, 20.0]), + np.array(["cat", "dog"]), + np.array([0, 1]), + ] self._assert_list_of_ndarray_equal(exp, rec) def test_getitem_int_str(self): """Test DynamicTableRegion.__getitem__ with (int, str).""" mc = self.roundtripContainer() - table = mc.containers['table_with_dtr'] - rec = table['dtr'][0, 'qux'] - self.assertEqual(rec, 'qux_1') + table = mc.containers["table_with_dtr"] + rec = table["dtr"][0, "qux"] + self.assertEqual(rec, "qux_1") def test_getitem_str(self): """Test DynamicTableRegion.__getitem__ with str.""" mc = self.roundtripContainer() - table = mc.containers['table_with_dtr'] - rec = table['dtr']['qux'] - self.assertIs(rec, mc.containers['target_table']['qux']) + table = mc.containers["table_with_dtr"] + rec = table["dtr"]["qux"] + self.assertIs(rec, mc.containers["target_table"]["qux"]) class TestElementIdentifiers(TestCase): - def setUp(self): - self.e = ElementIdentifiers(name='ids', data=[0, 1, 2, 3, 4]) + self.e = ElementIdentifiers(name="ids", data=[0, 1, 2, 3, 4]) def test_identifier_search_single_list(self): - a = (self.e == [1]) + a = self.e == [1] np.testing.assert_array_equal(a, [1]) def test_identifier_search_single_int(self): - a = (self.e == 2) + a = self.e == 2 np.testing.assert_array_equal(a, [2]) def test_identifier_search_single_list_not_found(self): - a = (self.e == [10]) + a = self.e == [10] np.testing.assert_array_equal(a, []) def test_identifier_search_single_int_not_found(self): - a = (self.e == 10) + a = self.e == 10 np.testing.assert_array_equal(a, []) def test_identifier_search_single_list_all_match(self): - a = (self.e == [1, 2, 3]) + a = self.e == [1, 2, 3] np.testing.assert_array_equal(a, [1, 2, 3]) def test_identifier_search_single_list_partial_match(self): - a = (self.e == [1, 2, 10]) + a = self.e == [1, 2, 10] np.testing.assert_array_equal(a, [1, 2]) - a = (self.e == [-1, 2, 10]) - np.testing.assert_array_equal(a, [2, ]) + a = self.e == [-1, 2, 10] + np.testing.assert_array_equal( + a, + [ + 2, + ], + ) def test_identifier_search_with_element_identifier(self): - a = (self.e == ElementIdentifiers(name='ids', data=[1, 2, 10])) + a = self.e == ElementIdentifiers(name="ids", data=[1, 2, 10]) np.testing.assert_array_equal(a, [1, 2]) def test_identifier_search_with_bad_ids(self): with self.assertRaises(TypeError): - _ = (self.e == 0.1) + _ = self.e == 0.1 with self.assertRaises(TypeError): - _ = (self.e == 'test') + _ = self.e == "test" class SubTable(DynamicTable): - __columns__ = ( - {'name': 'col1', 'description': 
'required column', 'required': True}, - {'name': 'col2', 'description': 'optional column'}, - {'name': 'col3', 'description': 'required, indexed column', 'required': True, 'index': True}, - {'name': 'col4', 'description': 'optional, indexed column', 'index': True}, - {'name': 'col5', 'description': 'required region', 'required': True, 'table': True}, - {'name': 'col6', 'description': 'optional region', 'table': True}, - {'name': 'col7', 'description': 'required, indexed region', 'required': True, 'index': True, 'table': True}, - {'name': 'col8', 'description': 'optional, indexed region', 'index': True, 'table': True}, - {'name': 'col10', 'description': 'optional, indexed enum column', 'index': True, 'class': EnumData}, - {'name': 'col11', 'description': 'optional, enumerable column', 'enum': True, 'index': True}, + {"name": "col1", "description": "required column", "required": True}, + {"name": "col2", "description": "optional column"}, + { + "name": "col3", + "description": "required, indexed column", + "required": True, + "index": True, + }, + { + "name": "col4", + "description": "optional, indexed column", + "index": True, + }, + { + "name": "col5", + "description": "required region", + "required": True, + "table": True, + }, + {"name": "col6", "description": "optional region", "table": True}, + { + "name": "col7", + "description": "required, indexed region", + "required": True, + "index": True, + "table": True, + }, + { + "name": "col8", + "description": "optional, indexed region", + "index": True, + "table": True, + }, + { + "name": "col10", + "description": "optional, indexed enum column", + "index": True, + "class": EnumData, + }, + { + "name": "col11", + "description": "optional, enumerable column", + "enum": True, + "index": True, + }, ) class SubSubTable(SubTable): - __columns__ = ( - {'name': 'col9', 'description': 'required column', 'required': True}, + {"name": "col9", "description": "required column", "required": True}, # TODO handle edge case where subclass re-defines a column from superclass # {'name': 'col2', 'description': 'optional column subsub', 'required': True}, # make col2 required ) @@ -1267,14 +1450,14 @@ class TestDynamicTableClassColumns(TestCase): def test_init(self): """Test that required columns, and not optional columns, in __columns__ are created on init.""" - table = SubTable(name='subtable', description='subtable description') - self.assertEqual(table.colnames, ('col1', 'col3', 'col5', 'col7')) + table = SubTable(name="subtable", description="subtable description") + self.assertEqual(table.colnames, ("col1", "col3", "col5", "col7")) # test different access methods. 
note: table.get('col1') is equivalent to table['col1'] - self.assertEqual(table.col1.description, 'required column') - self.assertEqual(table.col3.description, 'required, indexed column') - self.assertEqual(table.col5.description, 'required region') - self.assertEqual(table.col7.description, 'required, indexed region') - self.assertEqual(table['col1'].description, 'required column') + self.assertEqual(table.col1.description, "required column") + self.assertEqual(table.col3.description, "required, indexed column") + self.assertEqual(table.col5.description, "required region") + self.assertEqual(table.col7.description, "required, indexed region") + self.assertEqual(table["col1"].description, "required column") # self.assertEqual(table['col3'].description, 'required, indexed column') # TODO this should work self.assertIsNone(table.col2) @@ -1288,263 +1471,322 @@ def test_init(self): # uninitialized optional predefined columns cannot be accessed in this manner with self.assertRaises(KeyError): - table['col2'] + table["col2"] def test_gather_columns_inheritance(self): """Test that gathering columns across a type hierarchy works.""" - table = SubSubTable(name='subtable', description='subtable description') - self.assertEqual(table.colnames, ('col1', 'col3', 'col5', 'col7', 'col9')) + table = SubSubTable(name="subtable", description="subtable description") + self.assertEqual(table.colnames, ("col1", "col3", "col5", "col7", "col9")) def test_bad_predefined_columns(self): """Test that gathering columns across a type hierarchy works.""" msg = "'__columns__' must be of type tuple, found " with self.assertRaisesWith(TypeError, msg): - class BadSubTable(DynamicTable): + class BadSubTable(DynamicTable): __columns__ = [] def test_add_req_column(self): """Test that adding a required column from __columns__ raises an error.""" - table = SubTable(name='subtable', description='subtable description') + table = SubTable(name="subtable", description="subtable description") msg = "column 'col1' already exists in SubTable 'subtable'" with self.assertRaisesWith(ValueError, msg): - table.add_column(name='col1', description='column #1') + table.add_column(name="col1", description="column #1") def test_add_req_ind_column(self): """Test that adding a required, indexed column from __columns__ raises an error.""" - table = SubTable(name='subtable', description='subtable description') + table = SubTable(name="subtable", description="subtable description") msg = "column 'col3' already exists in SubTable 'subtable'" with self.assertRaisesWith(ValueError, msg): - table.add_column(name='col3', description='column #3') + table.add_column(name="col3", description="column #3") def test_add_opt_column(self): """Test that adding an optional column from __columns__ with matching specs except for description works.""" - table = SubTable(name='subtable', description='subtable description') + table = SubTable(name="subtable", description="subtable description") - table.add_column(name='col2', description='column #2') # override __columns__ description - self.assertEqual(table.col2.description, 'column #2') + table.add_column(name="col2", description="column #2") # override __columns__ description + self.assertEqual(table.col2.description, "column #2") - table.add_column(name='col4', description='column #4', index=True) - self.assertEqual(table.col4.description, 'column #4') + table.add_column(name="col4", description="column #4", index=True) + self.assertEqual(table.col4.description, "column #4") - table.add_column(name='col6', 
description='column #6', table=True) - self.assertEqual(table.col6.description, 'column #6') + table.add_column(name="col6", description="column #6", table=True) + self.assertEqual(table.col6.description, "column #6") - table.add_column(name='col8', description='column #8', index=True, table=True) - self.assertEqual(table.col8.description, 'column #8') + table.add_column(name="col8", description="column #8", index=True, table=True) + self.assertEqual(table.col8.description, "column #8") - table.add_column(name='col10', description='column #10', index=True, col_cls=EnumData) + table.add_column(name="col10", description="column #10", index=True, col_cls=EnumData) self.assertIsInstance(table.col10, EnumData) - table.add_column(name='col11', description='column #11', enum=True, index=True) + table.add_column(name="col11", description="column #11", enum=True, index=True) self.assertIsInstance(table.col11, EnumData) def test_add_opt_column_mismatched_table_true(self): """Test that adding an optional column from __columns__ with non-matched table raises a warning.""" - table = SubTable(name='subtable', description='subtable description') - msg = ("Column 'col2' is predefined in SubTable with table=False which does not match the entered table " - "argument. The predefined table spec will be ignored. " - "Please ensure the new column complies with the spec. " - "This will raise an error in a future version of HDMF.") + table = SubTable(name="subtable", description="subtable description") + msg = ( + "Column 'col2' is predefined in SubTable with table=False which does not" + " match the entered table argument. The predefined table spec will be" + " ignored. Please ensure the new column complies with the spec. This will" + " raise an error in a future version of HDMF." + ) with self.assertWarnsWith(UserWarning, msg): - table.add_column(name='col2', description='column #2', table=True) - self.assertEqual(table.col2.description, 'column #2') + table.add_column(name="col2", description="column #2", table=True) + self.assertEqual(table.col2.description, "column #2") self.assertEqual(type(table.col2), DynamicTableRegion) # not VectorData def test_add_opt_column_mismatched_table_table(self): """Test that adding an optional column from __columns__ with non-matched table raises a warning.""" - table = SubTable(name='subtable', description='subtable description') - msg = ("Column 'col2' is predefined in SubTable with table=False which does not match the entered table " - "argument. The predefined table spec will be ignored. " - "Please ensure the new column complies with the spec. " - "This will raise an error in a future version of HDMF.") + table = SubTable(name="subtable", description="subtable description") + msg = ( + "Column 'col2' is predefined in SubTable with table=False which does not" + " match the entered table argument. The predefined table spec will be" + " ignored. Please ensure the new column complies with the spec. This will" + " raise an error in a future version of HDMF." 
+ ) with self.assertWarnsWith(UserWarning, msg): - table.add_column(name='col2', description='column #2', - table=DynamicTable(name='dummy', description='dummy')) - self.assertEqual(table.col2.description, 'column #2') + table.add_column( + name="col2", + description="column #2", + table=DynamicTable(name="dummy", description="dummy"), + ) + self.assertEqual(table.col2.description, "column #2") self.assertEqual(type(table.col2), DynamicTableRegion) # not VectorData def test_add_opt_column_mismatched_index_true(self): """Test that adding an optional column from __columns__ with non-matched table raises a warning.""" - table = SubTable(name='subtable', description='subtable description') - msg = ("Column 'col2' is predefined in SubTable with index=False which does not match the entered index " - "argument. The predefined index spec will be ignored. " - "Please ensure the new column complies with the spec. " - "This will raise an error in a future version of HDMF.") + table = SubTable(name="subtable", description="subtable description") + msg = ( + "Column 'col2' is predefined in SubTable with index=False which does not" + " match the entered index argument. The predefined index spec will be" + " ignored. Please ensure the new column complies with the spec. This will" + " raise an error in a future version of HDMF." + ) with self.assertWarnsWith(UserWarning, msg): - table.add_column(name='col2', description='column #2', index=True) - self.assertEqual(table.col2.description, 'column #2') - self.assertEqual(type(table.get('col2')), VectorIndex) # not VectorData + table.add_column(name="col2", description="column #2", index=True) + self.assertEqual(table.col2.description, "column #2") + self.assertEqual(type(table.get("col2")), VectorIndex) # not VectorData def test_add_opt_column_mismatched_index_data(self): """Test that adding an optional column from __columns__ with non-matched table raises a warning.""" - table = SubTable(name='subtable', description='subtable description') - table.add_row(col1='a', col3='c', col5='e', col7='g') - table.add_row(col1='a', col3='c', col5='e', col7='g') - msg = ("Column 'col2' is predefined in SubTable with index=False which does not match the entered index " - "argument. The predefined index spec will be ignored. " - "Please ensure the new column complies with the spec. " - "This will raise an error in a future version of HDMF.") + table = SubTable(name="subtable", description="subtable description") + table.add_row(col1="a", col3="c", col5="e", col7="g") + table.add_row(col1="a", col3="c", col5="e", col7="g") + msg = ( + "Column 'col2' is predefined in SubTable with index=False which does not" + " match the entered index argument. The predefined index spec will be" + " ignored. Please ensure the new column complies with the spec. This will" + " raise an error in a future version of HDMF." 
+ ) with self.assertWarnsWith(UserWarning, msg): - table.add_column(name='col2', description='column #2', data=[1, 2, 3], index=[1, 2]) - self.assertEqual(table.col2.description, 'column #2') - self.assertEqual(type(table.get('col2')), VectorIndex) # not VectorData + table.add_column( + name="col2", + description="column #2", + data=[1, 2, 3], + index=[1, 2], + ) + self.assertEqual(table.col2.description, "column #2") + self.assertEqual(type(table.get("col2")), VectorIndex) # not VectorData def test_add_opt_column_mismatched_col_cls(self): """Test that adding an optional column from __columns__ with non-matched table raises a warning.""" - table = SubTable(name='subtable', description='subtable description') - msg = ("Column 'col10' is predefined in SubTable with class= " - "which does not match the entered col_cls " - "argument. The predefined class spec will be ignored. " - "Please ensure the new column complies with the spec. " - "This will raise an error in a future version of HDMF.") + table = SubTable(name="subtable", description="subtable description") + msg = ( + "Column 'col10' is predefined in SubTable with class= which does not match the entered col_cls" + " argument. The predefined class spec will be ignored. Please ensure the" + " new column complies with the spec. This will raise an error in a future" + " version of HDMF." + ) with self.assertWarnsWith(UserWarning, msg): - table.add_column(name='col10', description='column #10', index=True) - self.assertEqual(table.col10.description, 'column #10') + table.add_column(name="col10", description="column #10", index=True) + self.assertEqual(table.col10.description, "column #10") self.assertEqual(type(table.col10), VectorData) - self.assertEqual(type(table.get('col10')), VectorIndex) + self.assertEqual(type(table.get("col10")), VectorIndex) def test_add_opt_column_twice(self): """Test that adding an optional column from __columns__ twice fails the second time.""" - table = SubTable(name='subtable', description='subtable description') - table.add_column(name='col2', description='column #2') + table = SubTable(name="subtable", description="subtable description") + table.add_column(name="col2", description="column #2") msg = "column 'col2' already exists in SubTable 'subtable'" with self.assertRaisesWith(ValueError, msg): - table.add_column(name='col2', description='column #2b') + table.add_column(name="col2", description="column #2b") def test_add_opt_column_after_data(self): """Test that adding an optional column from __columns__ with data works.""" - table = SubTable(name='subtable', description='subtable description') - table.add_row(col1='a', col3='c', col5='e', col7='g') - table.add_column(name='col2', description='column #2', data=('b', )) - self.assertTupleEqual(table.col2.data, ('b', )) + table = SubTable(name="subtable", description="subtable description") + table.add_row(col1="a", col3="c", col5="e", col7="g") + table.add_column(name="col2", description="column #2", data=("b",)) + self.assertTupleEqual(table.col2.data, ("b",)) def test_add_opt_ind_column_after_data(self): """Test that adding an optional, indexed column from __columns__ with data works.""" - table = SubTable(name='subtable', description='subtable description') - table.add_row(col1='a', col3='c', col5='e', col7='g') + table = SubTable(name="subtable", description="subtable description") + table.add_row(col1="a", col3="c", col5="e", col7="g") # TODO this use case is tricky and should not be allowed # table.add_column(name='col4', description='column #4', 
data=(('b', 'b2'), )) def test_add_row_opt_column(self): """Test that adding a row with an optional column works.""" - table = SubTable(name='subtable', description='subtable description') - table.add_row(col1='a', col2='b', col3='c', col4=('d1', 'd2'), col5='e', col7='g') - table.add_row(col1='a', col2='b2', col3='c', col4=('d3', 'd4'), col5='e', col7='g') - self.assertTupleEqual(table.colnames, ('col1', 'col3', 'col5', 'col7', 'col2', 'col4')) - self.assertEqual(table.col2.description, 'optional column') - self.assertEqual(table.col4.description, 'optional, indexed column') - self.assertListEqual(table.col2.data, ['b', 'b2']) + table = SubTable(name="subtable", description="subtable description") + table.add_row(col1="a", col2="b", col3="c", col4=("d1", "d2"), col5="e", col7="g") + table.add_row(col1="a", col2="b2", col3="c", col4=("d3", "d4"), col5="e", col7="g") + self.assertTupleEqual(table.colnames, ("col1", "col3", "col5", "col7", "col2", "col4")) + self.assertEqual(table.col2.description, "optional column") + self.assertEqual(table.col4.description, "optional, indexed column") + self.assertListEqual(table.col2.data, ["b", "b2"]) # self.assertListEqual(table.col4.data, [('d1', 'd2'), ('d3', 'd4')]) # TODO this should work def test_add_row_opt_column_after_data(self): """Test that adding a row with an optional column after adding a row without the column raises an error.""" - table = SubTable(name='subtable', description='subtable description') - table.add_row(col1='a', col3='c', col5='e', col7='g') + table = SubTable(name="subtable", description="subtable description") + table.add_row(col1="a", col3="c", col5="e", col7="g") msg = "column must have the same number of rows as 'id'" # TODO improve error message with self.assertRaisesWith(ValueError, msg): - table.add_row(col1='a', col2='b', col3='c', col5='e', col7='g') + table.add_row(col1="a", col2="b", col3="c", col5="e", col7="g") def test_init_columns_add_req_column(self): """Test that passing a required column to init works.""" - col1 = VectorData(name='col1', description='column #1') # override __columns__ description - table = SubTable(name='subtable', description='subtable description', columns=[col1]) - self.assertEqual(table.colnames, ('col1', 'col3', 'col5', 'col7')) - self.assertEqual(table.col1.description, 'column #1') - self.assertTrue(hasattr(table, 'col1')) + col1 = VectorData(name="col1", description="column #1") # override __columns__ description + table = SubTable(name="subtable", description="subtable description", columns=[col1]) + self.assertEqual(table.colnames, ("col1", "col3", "col5", "col7")) + self.assertEqual(table.col1.description, "column #1") + self.assertTrue(hasattr(table, "col1")) def test_init_columns_add_req_column_mismatch_index(self): """Test that passing a required column that does not match the predefined column specs raises an error.""" - col1 = VectorData(name='col1', description='column #1') # override __columns__ description - col1_ind = VectorIndex(name='col1_index', data=list(), target=col1) + col1 = VectorData(name="col1", description="column #1") # override __columns__ description + col1_ind = VectorIndex(name="col1_index", data=list(), target=col1) # TODO raise an error - SubTable(name='subtable', description='subtable description', columns=[col1_ind, col1]) + SubTable( + name="subtable", + description="subtable description", + columns=[col1_ind, col1], + ) def test_init_columns_add_req_column_mismatch_table(self): """Test that passing a required column that does not match the 
predefined column specs raises an error.""" - dummy_table = DynamicTable(name='dummy', description='dummy table') - col1 = DynamicTableRegion(name='col1', data=list(), description='column #1', table=dummy_table) + dummy_table = DynamicTable(name="dummy", description="dummy table") + col1 = DynamicTableRegion(name="col1", data=list(), description="column #1", table=dummy_table) # TODO raise an error - SubTable(name='subtable', description='subtable description', columns=[col1]) + SubTable(name="subtable", description="subtable description", columns=[col1]) def test_init_columns_add_opt_column(self): """Test that passing an optional column to init works.""" - col2 = VectorData(name='col2', description='column #2') # override __columns__ description - table = SubTable(name='subtable', description='subtable description', columns=[col2]) - self.assertEqual(table.colnames, ('col2', 'col1', 'col3', 'col5', 'col7')) - self.assertEqual(table.col2.description, 'column #2') + col2 = VectorData(name="col2", description="column #2") # override __columns__ description + table = SubTable(name="subtable", description="subtable description", columns=[col2]) + self.assertEqual(table.colnames, ("col2", "col1", "col3", "col5", "col7")) + self.assertEqual(table.col2.description, "column #2") def test_init_columns_add_dup_column(self): """Test that passing two columns with the same name raises an error.""" - col1 = VectorData(name='col1', description='column #1') # override __columns__ description - col1_ind = VectorIndex(name='col1', data=list(), target=col1) + col1 = VectorData(name="col1", description="column #1") # override __columns__ description + col1_ind = VectorIndex(name="col1", data=list(), target=col1) msg = "'columns' contains columns with duplicate names: ['col1', 'col1']" with self.assertRaisesWith(ValueError, msg): - SubTable(name='subtable', description='subtable description', columns=[col1_ind, col1]) + SubTable( + name="subtable", + description="subtable description", + columns=[col1_ind, col1], + ) class TestEnumData(TestCase): - def test_init(self): - ed = EnumData(name='cv_data', description='a test EnumData', elements=['a', 'b', 'c'], - data=np.array([0, 0, 1, 1, 2, 2])) + ed = EnumData( + name="cv_data", + description="a test EnumData", + elements=["a", "b", "c"], + data=np.array([0, 0, 1, 1, 2, 2]), + ) self.assertIsInstance(ed.elements, VectorData) def test_get(self): - ed = EnumData(name='cv_data', description='a test EnumData', elements=['a', 'b', 'c'], - data=np.array([0, 0, 1, 1, 2, 2])) + ed = EnumData( + name="cv_data", + description="a test EnumData", + elements=["a", "b", "c"], + data=np.array([0, 0, 1, 1, 2, 2]), + ) dat = ed[2] - self.assertEqual(dat, 'b') + self.assertEqual(dat, "b") dat = ed[-1] - self.assertEqual(dat, 'c') + self.assertEqual(dat, "c") dat = ed[0] - self.assertEqual(dat, 'a') + self.assertEqual(dat, "a") def test_get_list(self): - ed = EnumData(name='cv_data', description='a test EnumData', elements=['a', 'b', 'c'], - data=np.array([0, 0, 1, 1, 2, 2])) + ed = EnumData( + name="cv_data", + description="a test EnumData", + elements=["a", "b", "c"], + data=np.array([0, 0, 1, 1, 2, 2]), + ) dat = ed[[0, 1, 2]] - np.testing.assert_array_equal(dat, ['a', 'a', 'b']) + np.testing.assert_array_equal(dat, ["a", "a", "b"]) def test_get_list_join(self): - ed = EnumData(name='cv_data', description='a test EnumData', elements=['a', 'b', 'c'], - data=np.array([0, 0, 1, 1, 2, 2])) + ed = EnumData( + name="cv_data", + description="a test EnumData", + elements=["a", "b", 
"c"], + data=np.array([0, 0, 1, 1, 2, 2]), + ) dat = ed.get([0, 1, 2], join=True) - self.assertEqual(dat, 'aab') + self.assertEqual(dat, "aab") def test_get_list_indices(self): - ed = EnumData(name='cv_data', description='a test EnumData', elements=['a', 'b', 'c'], - data=np.array([0, 0, 1, 1, 2, 2])) + ed = EnumData( + name="cv_data", + description="a test EnumData", + elements=["a", "b", "c"], + data=np.array([0, 0, 1, 1, 2, 2]), + ) dat = ed.get([0, 1, 2], index=True) np.testing.assert_array_equal(dat, [0, 0, 1]) def test_get_2d(self): - ed = EnumData(name='cv_data', description='a test EnumData', - elements=['a', 'b', 'c'], - data=np.array([[0, 0], [1, 1], [2, 2]])) + ed = EnumData( + name="cv_data", + description="a test EnumData", + elements=["a", "b", "c"], + data=np.array([[0, 0], [1, 1], [2, 2]]), + ) dat = ed[0] - np.testing.assert_array_equal(dat, ['a', 'a']) + np.testing.assert_array_equal(dat, ["a", "a"]) def test_get_2d_w_2d(self): - ed = EnumData(name='cv_data', description='a test EnumData', - elements=['a', 'b', 'c'], - data=np.array([[0, 0], [1, 1], [2, 2]])) + ed = EnumData( + name="cv_data", + description="a test EnumData", + elements=["a", "b", "c"], + data=np.array([[0, 0], [1, 1], [2, 2]]), + ) dat = ed[[0, 1]] - np.testing.assert_array_equal(dat, [['a', 'a'], ['b', 'b']]) + np.testing.assert_array_equal(dat, [["a", "a"], ["b", "b"]]) def test_add_row(self): - ed = EnumData(name='cv_data', description='a test EnumData', elements=['a', 'b', 'c']) - ed.add_row('b') - ed.add_row('a') - ed.add_row('c') + ed = EnumData( + name="cv_data", + description="a test EnumData", + elements=["a", "b", "c"], + ) + ed.add_row("b") + ed.add_row("a") + ed.add_row("c") np.testing.assert_array_equal(ed.data, np.array([1, 0, 2], dtype=np.uint8)) def test_add_row_index(self): - ed = EnumData(name='cv_data', description='a test EnumData', elements=['a', 'b', 'c']) + ed = EnumData( + name="cv_data", + description="a test EnumData", + elements=["a", "b", "c"], + ) ed.add_row(1, index=True) ed.add_row(0, index=True) ed.add_row(2, index=True) @@ -1552,142 +1794,210 @@ def test_add_row_index(self): class TestIndexedEnumData(TestCase): - def test_init(self): - ed = EnumData(name='cv_data', description='a test EnumData', - elements=['a', 'b', 'c'], data=np.array([0, 0, 1, 1, 2, 2])) - idx = VectorIndex(name='enum_index', data=[2, 4, 6], target=ed) + ed = EnumData( + name="cv_data", + description="a test EnumData", + elements=["a", "b", "c"], + data=np.array([0, 0, 1, 1, 2, 2]), + ) + idx = VectorIndex(name="enum_index", data=[2, 4, 6], target=ed) self.assertIsInstance(ed.elements, VectorData) self.assertIsInstance(idx.target, EnumData) def test_add_row(self): - ed = EnumData(name='cv_data', description='a test EnumData', elements=['a', 'b', 'c']) - idx = VectorIndex(name='enum_index', data=list(), target=ed) - idx.add_row(['a', 'a', 'a']) - idx.add_row(['b', 'b']) - idx.add_row(['c', 'c', 'c', 'c']) - np.testing.assert_array_equal(idx[0], ['a', 'a', 'a']) - np.testing.assert_array_equal(idx[1], ['b', 'b']) - np.testing.assert_array_equal(idx[2], ['c', 'c', 'c', 'c']) + ed = EnumData( + name="cv_data", + description="a test EnumData", + elements=["a", "b", "c"], + ) + idx = VectorIndex(name="enum_index", data=list(), target=ed) + idx.add_row(["a", "a", "a"]) + idx.add_row(["b", "b"]) + idx.add_row(["c", "c", "c", "c"]) + np.testing.assert_array_equal(idx[0], ["a", "a", "a"]) + np.testing.assert_array_equal(idx[1], ["b", "b"]) + np.testing.assert_array_equal(idx[2], ["c", "c", "c", "c"]) def 
test_add_row_index(self): - ed = EnumData(name='cv_data', description='a test EnumData', elements=['a', 'b', 'c']) - idx = VectorIndex(name='enum_index', data=list(), target=ed) + ed = EnumData( + name="cv_data", + description="a test EnumData", + elements=["a", "b", "c"], + ) + idx = VectorIndex(name="enum_index", data=list(), target=ed) idx.add_row([0, 0, 0], index=True) idx.add_row([1, 1], index=True) idx.add_row([2, 2, 2, 2], index=True) - np.testing.assert_array_equal(idx[0], ['a', 'a', 'a']) - np.testing.assert_array_equal(idx[1], ['b', 'b']) - np.testing.assert_array_equal(idx[2], ['c', 'c', 'c', 'c']) + np.testing.assert_array_equal(idx[0], ["a", "a", "a"]) + np.testing.assert_array_equal(idx[1], ["b", "b"]) + np.testing.assert_array_equal(idx[2], ["c", "c", "c", "c"]) @unittest.skip("feature is not yet supported") def test_add_2d_row_index(self): - ed = EnumData(name='cv_data', description='a test EnumData', elements=['a', 'b', 'c']) - idx = VectorIndex(name='enum_index', data=list(), target=ed) - idx.add_row([['a', 'a'], ['a', 'a'], ['a', 'a']]) - idx.add_row([['b', 'b'], ['b', 'b']]) - idx.add_row([['c', 'c'], ['c', 'c'], ['c', 'c'], ['c', 'c']]) - np.testing.assert_array_equal(idx[0], [['a', 'a'], ['a', 'a'], ['a', 'a']]) - np.testing.assert_array_equal(idx[1], [['b', 'b'], ['b', 'b']]) - np.testing.assert_array_equal(idx[2], [['c', 'c'], ['c', 'c'], ['c', 'c'], ['c', 'c']]) + ed = EnumData( + name="cv_data", + description="a test EnumData", + elements=["a", "b", "c"], + ) + idx = VectorIndex(name="enum_index", data=list(), target=ed) + idx.add_row([["a", "a"], ["a", "a"], ["a", "a"]]) + idx.add_row([["b", "b"], ["b", "b"]]) + idx.add_row([["c", "c"], ["c", "c"], ["c", "c"], ["c", "c"]]) + np.testing.assert_array_equal(idx[0], [["a", "a"], ["a", "a"], ["a", "a"]]) + np.testing.assert_array_equal(idx[1], [["b", "b"], ["b", "b"]]) + np.testing.assert_array_equal(idx[2], [["c", "c"], ["c", "c"], ["c", "c"], ["c", "c"]]) class SelectionTestMixin: - def setUp(self): # table1 contains a non-ragged DTR and a ragged DTR, both of which point to table2 # table2 contains a non-ragged DTR and a ragged DTR, both of which point to table3 - self.table3 = DynamicTable( - name='table3', - description='a test table', - id=[20, 21, 22] - ) - self.table3.add_column('foo', 'scalar column', data=self._wrap([20.0, 21.0, 22.0])) - self.table3.add_column('bar', 'ragged column', index=self._wrap([2, 3, 6]), - data=self._wrap(['t11', 't12', 't21', 't31', 't32', 't33'])) - self.table3.add_column('baz', 'multi-dimension column', - data=self._wrap([[210.0, 211.0, 212.0], - [220.0, 221.0, 222.0], - [230.0, 231.0, 232.0]])) + self.table3 = DynamicTable(name="table3", description="a test table", id=[20, 21, 22]) + self.table3.add_column("foo", "scalar column", data=self._wrap([20.0, 21.0, 22.0])) + self.table3.add_column( + "bar", + "ragged column", + index=self._wrap([2, 3, 6]), + data=self._wrap(["t11", "t12", "t21", "t31", "t32", "t33"]), + ) + self.table3.add_column( + "baz", + "multi-dimension column", + data=self._wrap( + [ + [210.0, 211.0, 212.0], + [220.0, 221.0, 222.0], + [230.0, 231.0, 232.0], + ] + ), + ) # generate expected dataframe for table3 data = OrderedDict() - data['foo'] = [20.0, 21.0, 22.0] - data['bar'] = [['t11', 't12'], ['t21'], ['t31', 't32', 't33']] - data['baz'] = [[210.0, 211.0, 212.0], [220.0, 221.0, 222.0], [230.0, 231.0, 232.0]] + data["foo"] = [20.0, 21.0, 22.0] + data["bar"] = [["t11", "t12"], ["t21"], ["t31", "t32", "t33"]] + data["baz"] = [ + [210.0, 211.0, 212.0], + [220.0, 
221.0, 222.0], + [230.0, 231.0, 232.0], + ] idx = [20, 21, 22] - self.table3_df = pd.DataFrame(data=data, index=pd.Index(name='id', data=idx)) - - self.table2 = DynamicTable( - name='table2', - description='a test table', - id=[10, 11, 12] - ) - self.table2.add_column('foo', 'scalar column', data=self._wrap([10.0, 11.0, 12.0])) - self.table2.add_column('bar', 'ragged column', index=self._wrap([2, 3, 6]), - data=self._wrap(['s11', 's12', 's21', 's31', 's32', 's33'])) - self.table2.add_column('baz', 'multi-dimension column', - data=self._wrap([[110.0, 111.0, 112.0], - [120.0, 121.0, 122.0], - [130.0, 131.0, 132.0]])) - self.table2.add_column('qux', 'DTR column', table=self.table3, data=self._wrap([0, 1, 0])) - self.table2.add_column('corge', 'ragged DTR column', index=self._wrap([2, 3, 6]), table=self.table3, - data=self._wrap([0, 1, 2, 0, 1, 2])) + self.table3_df = pd.DataFrame(data=data, index=pd.Index(name="id", data=idx)) + + self.table2 = DynamicTable(name="table2", description="a test table", id=[10, 11, 12]) + self.table2.add_column("foo", "scalar column", data=self._wrap([10.0, 11.0, 12.0])) + self.table2.add_column( + "bar", + "ragged column", + index=self._wrap([2, 3, 6]), + data=self._wrap(["s11", "s12", "s21", "s31", "s32", "s33"]), + ) + self.table2.add_column( + "baz", + "multi-dimension column", + data=self._wrap( + [ + [110.0, 111.0, 112.0], + [120.0, 121.0, 122.0], + [130.0, 131.0, 132.0], + ] + ), + ) + self.table2.add_column("qux", "DTR column", table=self.table3, data=self._wrap([0, 1, 0])) + self.table2.add_column( + "corge", + "ragged DTR column", + index=self._wrap([2, 3, 6]), + table=self.table3, + data=self._wrap([0, 1, 2, 0, 1, 2]), + ) # TODO test when ragged DTR indices are not in ascending order # generate expected dataframe for table2 *without DTR* data = OrderedDict() - data['foo'] = [10.0, 11.0, 12.0] - data['bar'] = [['s11', 's12'], ['s21'], ['s31', 's32', 's33']] - data['baz'] = [[110.0, 111.0, 112.0], [120.0, 121.0, 122.0], [130.0, 131.0, 132.0]] + data["foo"] = [10.0, 11.0, 12.0] + data["bar"] = [["s11", "s12"], ["s21"], ["s31", "s32", "s33"]] + data["baz"] = [ + [110.0, 111.0, 112.0], + [120.0, 121.0, 122.0], + [130.0, 131.0, 132.0], + ] idx = [10, 11, 12] - self.table2_df = pd.DataFrame(data=data, index=pd.Index(name='id', data=idx)) - - self.table1 = DynamicTable( - name='table1', - description='a table to test slicing', - id=[0, 1] - ) - self.table1.add_column('foo', 'scalar column', data=self._wrap([0.0, 1.0])) - self.table1.add_column('bar', 'ragged column', index=self._wrap([2, 3]), - data=self._wrap(['r11', 'r12', 'r21'])) - self.table1.add_column('baz', 'multi-dimension column', - data=self._wrap([[10.0, 11.0, 12.0], - [20.0, 21.0, 22.0]])) - self.table1.add_column('qux', 'DTR column', table=self.table2, data=self._wrap([0, 1])) - self.table1.add_column('corge', 'ragged DTR column', index=self._wrap([2, 3]), table=self.table2, - data=self._wrap([0, 1, 2])) - self.table1.add_column('barz', 'ragged column of tuples (cpd type)', index=self._wrap([2, 3]), - data=self._wrap([(1.0, 11), (2.0, 12), (3.0, 21)])) + self.table2_df = pd.DataFrame(data=data, index=pd.Index(name="id", data=idx)) + + self.table1 = DynamicTable(name="table1", description="a table to test slicing", id=[0, 1]) + self.table1.add_column("foo", "scalar column", data=self._wrap([0.0, 1.0])) + self.table1.add_column( + "bar", + "ragged column", + index=self._wrap([2, 3]), + data=self._wrap(["r11", "r12", "r21"]), + ) + self.table1.add_column( + "baz", + "multi-dimension column", + 
data=self._wrap([[10.0, 11.0, 12.0], [20.0, 21.0, 22.0]]), + ) + self.table1.add_column("qux", "DTR column", table=self.table2, data=self._wrap([0, 1])) + self.table1.add_column( + "corge", + "ragged DTR column", + index=self._wrap([2, 3]), + table=self.table2, + data=self._wrap([0, 1, 2]), + ) + self.table1.add_column( + "barz", + "ragged column of tuples (cpd type)", + index=self._wrap([2, 3]), + data=self._wrap([(1.0, 11), (2.0, 12), (3.0, 21)]), + ) # generate expected dataframe for table1 *without DTR* data = OrderedDict() - data['foo'] = self._wrap_check([0.0, 1.0]) - data['bar'] = [self._wrap_check(['r11', 'r12']), self._wrap_check(['r21'])] - data['baz'] = [self._wrap_check([10.0, 11.0, 12.0]), - self._wrap_check([20.0, 21.0, 22.0])] - data['barz'] = [self._wrap_check([(1.0, 11), (2.0, 12)]), self._wrap_check([(3.0, 21)])] + data["foo"] = self._wrap_check([0.0, 1.0]) + data["bar"] = [ + self._wrap_check(["r11", "r12"]), + self._wrap_check(["r21"]), + ] + data["baz"] = [ + self._wrap_check([10.0, 11.0, 12.0]), + self._wrap_check([20.0, 21.0, 22.0]), + ] + data["barz"] = [ + self._wrap_check([(1.0, 11), (2.0, 12)]), + self._wrap_check([(3.0, 21)]), + ] idx = [0, 1] - self.table1_df = pd.DataFrame(data=data, index=pd.Index(name='id', data=idx)) + self.table1_df = pd.DataFrame(data=data, index=pd.Index(name="id", data=idx)) def _check_two_rows_df(self, rec): data = OrderedDict() - data['foo'] = self._wrap_check([0.0, 1.0]) - data['bar'] = [self._wrap_check(['r11', 'r12']), self._wrap_check(['r21'])] - data['baz'] = [self._wrap_check([10.0, 11.0, 12.0]), - self._wrap_check([20.0, 21.0, 22.0])] - data['qux'] = self._wrap_check([0, 1]) - data['corge'] = [self._wrap_check([0, 1]), self._wrap_check([2])] - data['barz'] = [self._wrap_check([(1.0, 11), (2.0, 12)]), self._wrap_check([(3.0, 21)])] + data["foo"] = self._wrap_check([0.0, 1.0]) + data["bar"] = [ + self._wrap_check(["r11", "r12"]), + self._wrap_check(["r21"]), + ] + data["baz"] = [ + self._wrap_check([10.0, 11.0, 12.0]), + self._wrap_check([20.0, 21.0, 22.0]), + ] + data["qux"] = self._wrap_check([0, 1]) + data["corge"] = [self._wrap_check([0, 1]), self._wrap_check([2])] + data["barz"] = [ + self._wrap_check([(1.0, 11), (2.0, 12)]), + self._wrap_check([(3.0, 21)]), + ] idx = [0, 1] - exp = pd.DataFrame(data=data, index=pd.Index(name='id', data=idx)) + exp = pd.DataFrame(data=data, index=pd.Index(name="id", data=idx)) pd.testing.assert_frame_equal(rec, exp) def _check_two_rows_df_nested(self, rec): # first level: cache nested df cols and remove them before calling pd.testing.assert_frame_equal - qux_series = rec['qux'] - corge_series = rec['corge'] - del rec['qux'] - del rec['corge'] + qux_series = rec["qux"] + corge_series = rec["corge"] + del rec["qux"] + del rec["corge"] idx = [0, 1] pd.testing.assert_frame_equal(rec, self.table1_df.loc[idx]) @@ -1707,22 +2017,22 @@ def _check_two_rows_df_nested(self, rec): def _check_one_row_df(self, rec): data = OrderedDict() - data['foo'] = self._wrap_check([0.0]) - data['bar'] = [self._wrap_check(['r11', 'r12'])] - data['baz'] = [self._wrap_check([10.0, 11.0, 12.0])] - data['qux'] = self._wrap_check([0]) - data['corge'] = [self._wrap_check([0, 1])] - data['barz'] = [self._wrap_check([(1.0, 11), (2.0, 12)])] + data["foo"] = self._wrap_check([0.0]) + data["bar"] = [self._wrap_check(["r11", "r12"])] + data["baz"] = [self._wrap_check([10.0, 11.0, 12.0])] + data["qux"] = self._wrap_check([0]) + data["corge"] = [self._wrap_check([0, 1])] + data["barz"] = [self._wrap_check([(1.0, 11), (2.0, 12)])] 
idx = [0] - exp = pd.DataFrame(data=data, index=pd.Index(name='id', data=idx)) + exp = pd.DataFrame(data=data, index=pd.Index(name="id", data=idx)) pd.testing.assert_frame_equal(rec, exp) def _check_one_row_df_nested(self, rec): # first level: cache nested df cols and remove them before calling pd.testing.assert_frame_equal - qux_series = rec['qux'] - corge_series = rec['corge'] - del rec['qux'] - del rec['corge'] + qux_series = rec["qux"] + corge_series = rec["corge"] + del rec["qux"] + del rec["corge"] idx = [0] pd.testing.assert_frame_equal(rec, self.table1_df.loc[idx]) @@ -1738,10 +2048,10 @@ def _check_one_row_df_nested(self, rec): def _check_table2_first_row_qux(self, rec_qux): # second level: cache nested df cols and remove them before calling pd.testing.assert_frame_equal - qux_qux_series = rec_qux['qux'] - qux_corge_series = rec_qux['corge'] - del rec_qux['qux'] - del rec_qux['corge'] + qux_qux_series = rec_qux["qux"] + qux_corge_series = rec_qux["corge"] + del rec_qux["qux"] + del rec_qux["corge"] qux_idx = [10] pd.testing.assert_frame_equal(rec_qux, self.table2_df.loc[qux_idx]) @@ -1754,10 +2064,10 @@ def _check_table2_first_row_qux(self, rec_qux): def _check_table2_second_row_qux(self, rec_qux): # second level: cache nested df cols and remove them before calling pd.testing.assert_frame_equal - qux_qux_series = rec_qux['qux'] - qux_corge_series = rec_qux['corge'] - del rec_qux['qux'] - del rec_qux['corge'] + qux_qux_series = rec_qux["qux"] + qux_corge_series = rec_qux["corge"] + del rec_qux["qux"] + del rec_qux["corge"] qux_idx = [11] pd.testing.assert_frame_equal(rec_qux, self.table2_df.loc[qux_idx]) @@ -1770,10 +2080,10 @@ def _check_table2_second_row_qux(self, rec_qux): def _check_table2_first_row_corge(self, rec_corge): # second level: cache nested df cols and remove them before calling pd.testing.assert_frame_equal - corge_qux_series = rec_corge['qux'] - corge_corge_series = rec_corge['corge'] - del rec_corge['qux'] - del rec_corge['corge'] + corge_qux_series = rec_corge["qux"] + corge_corge_series = rec_corge["corge"] + del rec_corge["qux"] + del rec_corge["corge"] corge_idx = [10, 11] pd.testing.assert_frame_equal(rec_corge, self.table2_df.loc[corge_idx]) @@ -1788,10 +2098,10 @@ def _check_table2_first_row_corge(self, rec_corge): def _check_table2_second_row_corge(self, rec_corge): # second level: cache nested df cols and remove them before calling pd.testing.assert_frame_equal - corge_qux_series = rec_corge['qux'] - corge_corge_series = rec_corge['corge'] - del rec_corge['qux'] - del rec_corge['corge'] + corge_qux_series = rec_corge["qux"] + corge_corge_series = rec_corge["corge"] + del rec_corge["qux"] + del rec_corge["corge"] corge_idx = [12] pd.testing.assert_frame_equal(rec_corge, self.table2_df.loc[corge_idx]) @@ -1805,7 +2115,7 @@ def _check_table2_second_row_corge(self, rec_corge): def _check_two_rows_no_df(self, rec): self.assertEqual(rec[0], [0, 1]) np.testing.assert_array_equal(rec[1], self._wrap_check([0.0, 1.0])) - expected = [self._wrap_check(['r11', 'r12']), self._wrap_check(['r21'])] + expected = [self._wrap_check(["r11", "r12"]), self._wrap_check(["r21"])] self._assertNestedRaggedArrayEqual(rec[2], expected) np.testing.assert_array_equal(rec[3], self._wrap_check([[10.0, 11.0, 12.0], [20.0, 21.0, 22.0]])) np.testing.assert_array_equal(rec[4], self._wrap_check([0, 1])) @@ -1816,7 +2126,7 @@ def _check_two_rows_no_df(self, rec): def _check_one_row_no_df(self, rec): self.assertEqual(rec[0], 0) self.assertEqual(rec[1], 0.0) - 
np.testing.assert_array_equal(rec[2], self._wrap_check(['r11', 'r12'])) + np.testing.assert_array_equal(rec[2], self._wrap_check(["r11", "r12"])) np.testing.assert_array_equal(rec[3], self._wrap_check([10.0, 11.0, 12.0])) self.assertEqual(rec[4], 0) np.testing.assert_array_equal(rec[5], self._wrap_check([0, 1])) @@ -1826,7 +2136,7 @@ def _check_one_row_multiselect_no_df(self, rec): # difference from _check_one_row_no_df is that everything is wrapped in a list self.assertEqual(rec[0], [0]) self.assertEqual(rec[1], [0.0]) - np.testing.assert_array_equal(rec[2], [self._wrap_check(['r11', 'r12'])]) + np.testing.assert_array_equal(rec[2], [self._wrap_check(["r11", "r12"])]) np.testing.assert_array_equal(rec[3], [self._wrap_check([10.0, 11.0, 12.0])]) self.assertEqual(rec[4], [0]) np.testing.assert_array_equal(rec[5], [self._wrap_check([0, 1])]) @@ -1947,7 +2257,6 @@ def test_to_dataframe(self): class TestSelectionArray(SelectionTestMixin, TestCase): - def _wrap(self, my_list): return np.array(my_list) @@ -1956,7 +2265,6 @@ def _wrap_check(self, my_list): class TestSelectionList(SelectionTestMixin, TestCase): - def _wrap(self, my_list): return my_list @@ -1965,10 +2273,9 @@ def _wrap_check(self, my_list): class TestSelectionH5Dataset(SelectionTestMixin, TestCase): - def setUp(self): self.path = get_temp_filepath() - self.file = h5py.File(self.path, 'w') + self.file = h5py.File(self.path, "w") self.dset_counter = 0 super().setUp() @@ -1982,13 +2289,15 @@ def _wrap(self, my_list): self.dset_counter = self.dset_counter + 1 kwargs = dict() if isinstance(my_list[0], str): - kwargs['dtype'] = H5_TEXT + kwargs["dtype"] = H5_TEXT elif isinstance(my_list[0], tuple): # compound dtype # normally for cpd dtype, __resolve_dtype__ takes a list of DtypeSpec objects - cpd_type = [dict(name='cpd_float', dtype=np.dtype('float64')), - dict(name='cpd_int', dtype=np.dtype('int32'))] - kwargs['dtype'] = HDF5IO.__resolve_dtype__(cpd_type, my_list[0]) - dset = self.file.create_dataset('dset%d' % self.dset_counter, data=np.array(my_list, **kwargs)) + cpd_type = [ + dict(name="cpd_float", dtype=np.dtype("float64")), + dict(name="cpd_int", dtype=np.dtype("int32")), + ] + kwargs["dtype"] = HDF5IO.__resolve_dtype__(cpd_type, my_list[0]) + dset = self.file.create_dataset("dset%d" % self.dset_counter, data=np.array(my_list, **kwargs)) if H5PY_3 and isinstance(my_list[0], str): return StrDataset(dset, None) # return a wrapper to read data as str instead of bytes else: @@ -1999,202 +2308,261 @@ def _wrap_check(self, my_list): # getitem on h5dataset backed data will return np.array kwargs = dict() if isinstance(my_list[0], str): - kwargs['dtype'] = H5_TEXT + kwargs["dtype"] = H5_TEXT elif isinstance(my_list[0], tuple): - cpd_type = [dict(name='cpd_float', dtype=np.dtype('float64')), - dict(name='cpd_int', dtype=np.dtype('int32'))] - kwargs['dtype'] = np.dtype([(x['name'], x['dtype']) for x in cpd_type]) + cpd_type = [ + dict(name="cpd_float", dtype=np.dtype("float64")), + dict(name="cpd_int", dtype=np.dtype("int32")), + ] + kwargs["dtype"] = np.dtype([(x["name"], x["dtype"]) for x in cpd_type]) # compound dtypes with str are read as bytes, see https://github.com/h5py/h5py/issues/1751 return np.array(my_list, **kwargs) class TestVectorIndex(TestCase): - def test_init_empty(self): - foo = VectorData(name='foo', description='foo column') - foo_ind = VectorIndex(name='foo_index', target=foo, data=list()) - self.assertEqual(foo_ind.name, 'foo_index') + foo = VectorData(name="foo", description="foo column") + foo_ind = 
VectorIndex(name="foo_index", target=foo, data=list()) + self.assertEqual(foo_ind.name, "foo_index") self.assertEqual(foo_ind.description, "Index for VectorData 'foo'") self.assertIs(foo_ind.target, foo) self.assertListEqual(foo_ind.data, list()) def test_init_data(self): - foo = VectorData(name='foo', description='foo column', data=['a', 'b', 'c']) - foo_ind = VectorIndex(name='foo_index', target=foo, data=[2, 3]) + foo = VectorData(name="foo", description="foo column", data=["a", "b", "c"]) + foo_ind = VectorIndex(name="foo_index", target=foo, data=[2, 3]) self.assertListEqual(foo_ind.data, [2, 3]) - self.assertListEqual(foo_ind[0], ['a', 'b']) - self.assertListEqual(foo_ind[1], ['c']) + self.assertListEqual(foo_ind[0], ["a", "b"]) + self.assertListEqual(foo_ind[1], ["c"]) class TestDoubleIndex(TestCase): - def test_index(self): # row 1 has three entries # the first entry has two sub-entries # the first sub-entry has two values, the second sub-entry has one value # the second entry has one sub-entry, which has one value - foo = VectorData(name='foo', description='foo column', data=['a11', 'a12', 'a21', 'b11']) - foo_ind = VectorIndex(name='foo_index', target=foo, data=[2, 3, 4]) - foo_ind_ind = VectorIndex(name='foo_index_index', target=foo_ind, data=[2, 3]) + foo = VectorData( + name="foo", + description="foo column", + data=["a11", "a12", "a21", "b11"], + ) + foo_ind = VectorIndex(name="foo_index", target=foo, data=[2, 3, 4]) + foo_ind_ind = VectorIndex(name="foo_index_index", target=foo_ind, data=[2, 3]) - self.assertListEqual(foo_ind[0], ['a11', 'a12']) - self.assertListEqual(foo_ind[1], ['a21']) - self.assertListEqual(foo_ind[2], ['b11']) - self.assertListEqual(foo_ind_ind[0], [['a11', 'a12'], ['a21']]) - self.assertListEqual(foo_ind_ind[1], [['b11']]) + self.assertListEqual(foo_ind[0], ["a11", "a12"]) + self.assertListEqual(foo_ind[1], ["a21"]) + self.assertListEqual(foo_ind[2], ["b11"]) + self.assertListEqual(foo_ind_ind[0], [["a11", "a12"], ["a21"]]) + self.assertListEqual(foo_ind_ind[1], [["b11"]]) def test_add_vector(self): # row 1 has three entries # the first entry has two sub-entries # the first sub-entry has two values, the second sub-entry has one value # the second entry has one sub-entry, which has one value - foo = VectorData(name='foo', description='foo column', data=['a11', 'a12', 'a21', 'b11']) - foo_ind = VectorIndex(name='foo_index', target=foo, data=[2, 3, 4]) - foo_ind_ind = VectorIndex(name='foo_index_index', target=foo_ind, data=[2, 3]) + foo = VectorData( + name="foo", + description="foo column", + data=["a11", "a12", "a21", "b11"], + ) + foo_ind = VectorIndex(name="foo_index", target=foo, data=[2, 3, 4]) + foo_ind_ind = VectorIndex(name="foo_index_index", target=foo_ind, data=[2, 3]) - foo_ind_ind.add_vector([['c11', 'c12', 'c13'], ['c21', 'c22']]) + foo_ind_ind.add_vector([["c11", "c12", "c13"], ["c21", "c22"]]) - self.assertListEqual(foo.data, ['a11', 'a12', 'a21', 'b11', 'c11', 'c12', 'c13', 'c21', 'c22']) + self.assertListEqual( + foo.data, + ["a11", "a12", "a21", "b11", "c11", "c12", "c13", "c21", "c22"], + ) self.assertListEqual(foo_ind.data, [2, 3, 4, 7, 9]) - self.assertListEqual(foo_ind[3], ['c11', 'c12', 'c13']) - self.assertListEqual(foo_ind[4], ['c21', 'c22']) + self.assertListEqual(foo_ind[3], ["c11", "c12", "c13"]) + self.assertListEqual(foo_ind[4], ["c21", "c22"]) self.assertListEqual(foo_ind_ind.data, [2, 3, 5]) - self.assertListEqual(foo_ind_ind[2], [['c11', 'c12', 'c13'], ['c21', 'c22']]) + self.assertListEqual(foo_ind_ind[2], [["c11", 
"c12", "c13"], ["c21", "c22"]]) class TestDTDoubleIndex(TestCase): - def test_double_index(self): - foo = VectorData(name='foo', description='foo column', data=['a11', 'a12', 'a21', 'b11']) - foo_ind = VectorIndex(name='foo_index', target=foo, data=[2, 3, 4]) - foo_ind_ind = VectorIndex(name='foo_index_index', target=foo_ind, data=[2, 3]) + foo = VectorData( + name="foo", + description="foo column", + data=["a11", "a12", "a21", "b11"], + ) + foo_ind = VectorIndex(name="foo_index", target=foo, data=[2, 3, 4]) + foo_ind_ind = VectorIndex(name="foo_index_index", target=foo_ind, data=[2, 3]) - table = DynamicTable(name='table0', description='an example table', columns=[foo, foo_ind, foo_ind_ind]) + table = DynamicTable( + name="table0", + description="an example table", + columns=[foo, foo_ind, foo_ind_ind], + ) - self.assertIs(table['foo'], foo_ind_ind) + self.assertIs(table["foo"], foo_ind_ind) self.assertIs(table.foo, foo) - self.assertListEqual(table['foo'][0], [['a11', 'a12'], ['a21']]) - self.assertListEqual(table[0, 'foo'], [['a11', 'a12'], ['a21']]) - self.assertListEqual(table[1, 'foo'], [['b11']]) + self.assertListEqual(table["foo"][0], [["a11", "a12"], ["a21"]]) + self.assertListEqual(table[0, "foo"], [["a11", "a12"], ["a21"]]) + self.assertListEqual(table[1, "foo"], [["b11"]]) def test_double_index_reverse(self): - foo = VectorData(name='foo', description='foo column', data=['a11', 'a12', 'a21', 'b11']) - foo_ind = VectorIndex(name='foo_index', target=foo, data=[2, 3, 4]) - foo_ind_ind = VectorIndex(name='foo_index_index', target=foo_ind, data=[2, 3]) + foo = VectorData( + name="foo", + description="foo column", + data=["a11", "a12", "a21", "b11"], + ) + foo_ind = VectorIndex(name="foo_index", target=foo, data=[2, 3, 4]) + foo_ind_ind = VectorIndex(name="foo_index_index", target=foo_ind, data=[2, 3]) - table = DynamicTable(name='table0', description='an example table', columns=[foo_ind_ind, foo_ind, foo]) + table = DynamicTable( + name="table0", + description="an example table", + columns=[foo_ind_ind, foo_ind, foo], + ) - self.assertIs(table['foo'], foo_ind_ind) + self.assertIs(table["foo"], foo_ind_ind) self.assertIs(table.foo, foo) - self.assertListEqual(table['foo'][0], [['a11', 'a12'], ['a21']]) - self.assertListEqual(table[0, 'foo'], [['a11', 'a12'], ['a21']]) - self.assertListEqual(table[1, 'foo'], [['b11']]) + self.assertListEqual(table["foo"][0], [["a11", "a12"], ["a21"]]) + self.assertListEqual(table[0, "foo"], [["a11", "a12"], ["a21"]]) + self.assertListEqual(table[1, "foo"], [["b11"]]) def test_double_index_colnames(self): - foo = VectorData(name='foo', description='foo column', data=['a11', 'a12', 'a21', 'b11']) - foo_ind = VectorIndex(name='foo_index', target=foo, data=[2, 3, 4]) - foo_ind_ind = VectorIndex(name='foo_index_index', target=foo_ind, data=[2, 3]) - bar = VectorData(name='bar', description='bar column', data=[1, 2]) - - table = DynamicTable(name='table0', description='an example table', columns=[foo, foo_ind, foo_ind_ind, bar], - colnames=['foo', 'bar']) + foo = VectorData( + name="foo", + description="foo column", + data=["a11", "a12", "a21", "b11"], + ) + foo_ind = VectorIndex(name="foo_index", target=foo, data=[2, 3, 4]) + foo_ind_ind = VectorIndex(name="foo_index_index", target=foo_ind, data=[2, 3]) + bar = VectorData(name="bar", description="bar column", data=[1, 2]) + + table = DynamicTable( + name="table0", + description="an example table", + columns=[foo, foo_ind, foo_ind_ind, bar], + colnames=["foo", "bar"], + ) 
self.assertTupleEqual(table.columns, (foo_ind_ind, foo_ind, foo, bar)) def test_double_index_reverse_colnames(self): - foo = VectorData(name='foo', description='foo column', data=['a11', 'a12', 'a21', 'b11']) - foo_ind = VectorIndex(name='foo_index', target=foo, data=[2, 3, 4]) - foo_ind_ind = VectorIndex(name='foo_index_index', target=foo_ind, data=[2, 3]) - bar = VectorData(name='bar', description='bar column', data=[1, 2]) - - table = DynamicTable(name='table0', description='an example table', columns=[foo_ind_ind, foo_ind, foo, bar], - colnames=['bar', 'foo']) + foo = VectorData( + name="foo", + description="foo column", + data=["a11", "a12", "a21", "b11"], + ) + foo_ind = VectorIndex(name="foo_index", target=foo, data=[2, 3, 4]) + foo_ind_ind = VectorIndex(name="foo_index_index", target=foo_ind, data=[2, 3]) + bar = VectorData(name="bar", description="bar column", data=[1, 2]) + + table = DynamicTable( + name="table0", + description="an example table", + columns=[foo_ind_ind, foo_ind, foo, bar], + colnames=["bar", "foo"], + ) self.assertTupleEqual(table.columns, (bar, foo_ind_ind, foo_ind, foo)) class TestDTDoubleIndexSkipMiddle(TestCase): - def test_index(self): - foo = VectorData(name='foo', description='foo column', data=['a11', 'a12', 'a21', 'b11']) - foo_ind = VectorIndex(name='foo_index', target=foo, data=[2, 3, 4]) - foo_ind_ind = VectorIndex(name='foo_index_index', target=foo_ind, data=[2, 3]) + foo = VectorData( + name="foo", + description="foo column", + data=["a11", "a12", "a21", "b11"], + ) + foo_ind = VectorIndex(name="foo_index", target=foo, data=[2, 3, 4]) + foo_ind_ind = VectorIndex(name="foo_index_index", target=foo_ind, data=[2, 3]) msg = "Found VectorIndex 'foo_index_index' but not its target 'foo_index'" with self.assertRaisesWith(ValueError, msg): - DynamicTable(name='table0', description='an example table', columns=[foo_ind_ind, foo]) + DynamicTable( + name="table0", + description="an example table", + columns=[foo_ind_ind, foo], + ) class TestDynamicTableAddIndexRoundTrip(H5RoundTripMixin, TestCase): - def setUpContainer(self): - table = DynamicTable(name='table0', description='an example table') - table.add_column('foo', 'an int column', index=True) + table = DynamicTable(name="table0", description="an example table") + table.add_column("foo", "an int column", index=True) table.add_row(foo=[1, 2, 3]) return table class TestDynamicTableAddEnumRoundTrip(H5RoundTripMixin, TestCase): - def setUpContainer(self): - table = DynamicTable(name='table0', description='an example table') - table.add_column('bar', 'an enumerable column', enum=True) - table.add_row(bar='a') - table.add_row(bar='b') - table.add_row(bar='a') - table.add_row(bar='c') + table = DynamicTable(name="table0", description="an example table") + table.add_column("bar", "an enumerable column", enum=True) + table.add_row(bar="a") + table.add_row(bar="b") + table.add_row(bar="a") + table.add_row(bar="c") return table class TestDynamicTableAddEnum(TestCase): - def test_enum(self): - table = DynamicTable(name='table0', description='an example table') - table.add_column('bar', 'an enumerable column', enum=True) - table.add_row(bar='a') - table.add_row(bar='b') - table.add_row(bar='a') - table.add_row(bar='c') + table = DynamicTable(name="table0", description="an example table") + table.add_column("bar", "an enumerable column", enum=True) + table.add_row(bar="a") + table.add_row(bar="b") + table.add_row(bar="a") + table.add_row(bar="c") rec = table.to_dataframe() - exp = pd.DataFrame(data={'bar': ['a', 'b', 
'a', 'c']}, index=pd.Series(name='id', data=[0, 1, 2, 3])) + exp = pd.DataFrame( + data={"bar": ["a", "b", "a", "c"]}, + index=pd.Series(name="id", data=[0, 1, 2, 3]), + ) pd.testing.assert_frame_equal(exp, rec) def test_enum_index(self): - table = DynamicTable(name='table0', description='an example table') - table.add_column('bar', 'an indexed enumerable column', enum=True, index=True) - table.add_row(bar=['a', 'a', 'a']) - table.add_row(bar=['b', 'b', 'b', 'b']) - table.add_row(bar=['c', 'c']) + table = DynamicTable(name="table0", description="an example table") + table.add_column("bar", "an indexed enumerable column", enum=True, index=True) + table.add_row(bar=["a", "a", "a"]) + table.add_row(bar=["b", "b", "b", "b"]) + table.add_row(bar=["c", "c"]) rec = table.to_dataframe() - exp = pd.DataFrame(data={'bar': [['a', 'a', 'a'], - ['b', 'b', 'b', 'b'], - ['c', 'c']]}, - index=pd.Series(name='id', data=[0, 1, 2])) + exp = pd.DataFrame( + data={"bar": [["a", "a", "a"], ["b", "b", "b", "b"], ["c", "c"]]}, + index=pd.Series(name="id", data=[0, 1, 2]), + ) pd.testing.assert_frame_equal(exp, rec) class TestDynamicTableInitIndexRoundTrip(H5RoundTripMixin, TestCase): - def setUpContainer(self): - foo = VectorData(name='foo', description='foo column', data=['a', 'b', 'c']) - foo_ind = VectorIndex(name='foo_index', target=foo, data=[2, 3]) + foo = VectorData(name="foo", description="foo column", data=["a", "b", "c"]) + foo_ind = VectorIndex(name="foo_index", target=foo, data=[2, 3]) # NOTE: on construct, columns are ordered such that indices go before data, so create the table that way # for proper comparison of the columns list - table = DynamicTable(name='table0', description='an example table', columns=[foo_ind, foo]) + table = DynamicTable( + name="table0", + description="an example table", + columns=[foo_ind, foo], + ) return table class TestDoubleIndexRoundtrip(H5RoundTripMixin, TestCase): - def setUpContainer(self): - foo = VectorData(name='foo', description='foo column', data=['a11', 'a12', 'a21', 'b11']) - foo_ind = VectorIndex(name='foo_index', target=foo, data=[2, 3, 4]) - foo_ind_ind = VectorIndex(name='foo_index_index', target=foo_ind, data=[2, 3]) + foo = VectorData( + name="foo", + description="foo column", + data=["a11", "a12", "a21", "b11"], + ) + foo_ind = VectorIndex(name="foo_index", target=foo, data=[2, 3, 4]) + foo_ind_ind = VectorIndex(name="foo_index_index", target=foo_ind, data=[2, 3]) # NOTE: on construct, columns are ordered such that indices go before data, so create the table that way # for proper comparison of the columns list - table = DynamicTable(name='table0', description='an example table', columns=[foo_ind_ind, foo_ind, foo]) + table = DynamicTable( + name="table0", + description="an example table", + columns=[foo_ind_ind, foo_ind, foo], + ) return table @@ -2212,33 +2580,32 @@ def setUpContainer(self): fletcher32=True, allow_plugin_filters=True, ) - foo = VectorData(name='foo', description='chunked column', data=self.chunked_data) - bar = VectorData(name='bar', description='chunked column', data=self.compressed_data) + foo = VectorData(name="foo", description="chunked column", data=self.chunked_data) + bar = VectorData(name="bar", description="chunked column", data=self.compressed_data) # NOTE: on construct, columns are ordered such that indices go before data, so create the table that way # for proper comparison of the columns list - table = DynamicTable(name='table0', description='an example table', columns=[foo, bar]) + table = DynamicTable(name="table0", 
description="an example table", columns=[foo, bar]) table.add_row(foo=1, bar=1) return table def test_roundtrip(self): super().test_roundtrip() - with h5py.File(self.filename, 'r') as f: - chunked_dset = f['foo'] + with h5py.File(self.filename, "r") as f: + chunked_dset = f["foo"] self.assertTrue(np.all(chunked_dset[:] == self.chunked_data.data)) self.assertEqual(chunked_dset.chunks, (3,)) self.assertEqual(chunked_dset.fillvalue, -1) - compressed_dset = f['bar'] + compressed_dset = f["bar"] self.assertTrue(np.all(compressed_dset[:] == self.compressed_data.data)) - self.assertEqual(compressed_dset.compression, 'gzip') + self.assertEqual(compressed_dset.compression, "gzip") self.assertEqual(compressed_dset.shuffle, True) self.assertEqual(compressed_dset.fletcher32, True) class TestDataIOIndexedColumns(H5RoundTripMixin, TestCase): - def setUpContainer(self): self.chunked_data = H5DataIO( data=np.arange(30).reshape(5, 2, 3), @@ -2252,50 +2619,56 @@ def setUpContainer(self): fletcher32=True, allow_plugin_filters=True, ) - foo = VectorData(name='foo', description='chunked column', data=self.chunked_data) - foo_ind = VectorIndex(name='foo_index', target=foo, data=[2, 3, 4]) - bar = VectorData(name='bar', description='chunked column', data=self.compressed_data) - bar_ind = VectorIndex(name='bar_index', target=bar, data=[2, 3, 4]) + foo = VectorData(name="foo", description="chunked column", data=self.chunked_data) + foo_ind = VectorIndex(name="foo_index", target=foo, data=[2, 3, 4]) + bar = VectorData(name="bar", description="chunked column", data=self.compressed_data) + bar_ind = VectorIndex(name="bar_index", target=bar, data=[2, 3, 4]) # NOTE: on construct, columns are ordered such that indices go before data, so create the table that way # for proper comparison of the columns list - table = DynamicTable(name='table0', description='an example table', columns=[foo_ind, foo, bar_ind, bar]) + table = DynamicTable( + name="table0", + description="an example table", + columns=[foo_ind, foo, bar_ind, bar], + ) # check for add_row - table.add_row(foo=np.arange(30).reshape(5, 2, 3), bar=np.arange(30).reshape(5, 2, 3)) + table.add_row( + foo=np.arange(30).reshape(5, 2, 3), + bar=np.arange(30).reshape(5, 2, 3), + ) return table def test_roundtrip(self): super().test_roundtrip() - with h5py.File(self.filename, 'r') as f: - chunked_dset = f['foo'] + with h5py.File(self.filename, "r") as f: + chunked_dset = f["foo"] self.assertTrue(np.all(chunked_dset[:] == self.chunked_data.data)) self.assertEqual(chunked_dset.chunks, (1, 1, 3)) self.assertEqual(chunked_dset.fillvalue, -1) - compressed_dset = f['bar'] + compressed_dset = f["bar"] self.assertTrue(np.all(compressed_dset[:] == self.compressed_data.data)) - self.assertEqual(compressed_dset.compression, 'gzip') + self.assertEqual(compressed_dset.compression, "gzip") self.assertEqual(compressed_dset.shuffle, True) self.assertEqual(compressed_dset.fletcher32, True) class TestDataIOIndex(H5RoundTripMixin, TestCase): - def setUpContainer(self): self.chunked_data = H5DataIO( data=np.arange(30).reshape(5, 2, 3), chunks=(1, 1, 3), fillvalue=-1, - maxshape=(None, 2, 3) + maxshape=(None, 2, 3), ) self.chunked_index_data = H5DataIO( data=np.array([2, 3, 5], dtype=np.uint), - chunks=(2, ), + chunks=(2,), fillvalue=np.uint(10), - maxshape=(None,) + maxshape=(None,), ) self.compressed_data = H5DataIO( data=np.arange(30).reshape(5, 2, 3), @@ -2303,7 +2676,7 @@ def setUpContainer(self): shuffle=True, fletcher32=True, allow_plugin_filters=True, - maxshape=(None, 2, 3) + 
maxshape=(None, 2, 3), ) self.compressed_index_data = H5DataIO( data=np.array([2, 4, 5], dtype=np.uint), @@ -2311,141 +2684,134 @@ def setUpContainer(self): shuffle=True, fletcher32=False, allow_plugin_filters=True, - maxshape=(None,) + maxshape=(None,), ) - foo = VectorData(name='foo', description='chunked column', data=self.chunked_data) - foo_ind = VectorIndex(name='foo_index', target=foo, data=self.chunked_index_data) - bar = VectorData(name='bar', description='chunked column', data=self.compressed_data) - bar_ind = VectorIndex(name='bar_index', target=bar, data=self.compressed_index_data) + foo = VectorData(name="foo", description="chunked column", data=self.chunked_data) + foo_ind = VectorIndex(name="foo_index", target=foo, data=self.chunked_index_data) + bar = VectorData(name="bar", description="chunked column", data=self.compressed_data) + bar_ind = VectorIndex(name="bar_index", target=bar, data=self.compressed_index_data) # NOTE: on construct, columns are ordered such that indices go before data, so create the table that way # for proper comparison of the columns list - table = DynamicTable(name='table0', description='an example table', columns=[foo_ind, foo, bar_ind, bar], - id=H5DataIO(data=[0, 1, 2], chunks=True, maxshape=(None,))) + table = DynamicTable( + name="table0", + description="an example table", + columns=[foo_ind, foo, bar_ind, bar], + id=H5DataIO(data=[0, 1, 2], chunks=True, maxshape=(None,)), + ) # check for add_row - table.add_row(foo=np.arange(30).reshape(5, 2, 3), - bar=np.arange(30).reshape(5, 2, 3)) + table.add_row( + foo=np.arange(30).reshape(5, 2, 3), + bar=np.arange(30).reshape(5, 2, 3), + ) return table def test_append(self, cache_spec=False): """Write the container to an HDF5 file, read the container from the file, and append to it.""" - with HDF5IO(self.filename, manager=get_manager(), mode='w') as write_io: + with HDF5IO(self.filename, manager=get_manager(), mode="w") as write_io: write_io.write(self.container, cache_spec=cache_spec) - self.reader = HDF5IO(self.filename, manager=get_manager(), mode='a') + self.reader = HDF5IO(self.filename, manager=get_manager(), mode="a") read_table = self.reader.read() data = np.arange(30, 60).reshape(5, 2, 3) read_table.add_row(foo=data, bar=data) - np.testing.assert_array_equal(read_table['foo'][-1], data) + np.testing.assert_array_equal(read_table["foo"][-1], data) class TestDTRReferences(TestCase): - def setUp(self): - self.filename = 'test_dtr_references.h5' + self.filename = "test_dtr_references.h5" def tearDown(self): remove_test_file(self.filename) def test_dtr_references(self): """Test roundtrip of a table with a ragged DTR to another table containing a column of references.""" - group1 = Container('group1') - group2 = Container('group2') + group1 = Container("group1") + group2 = Container("group2") - table1 = DynamicTable( - name='table1', - description='test table 1' - ) - table1.add_column( - name='x', - description='test column of ints' - ) - table1.add_column( - name='y', - description='test column of reference' - ) + table1 = DynamicTable(name="table1", description="test table 1") + table1.add_column(name="x", description="test column of ints") + table1.add_column(name="y", description="test column of reference") table1.add_row(id=101, x=1, y=group1) table1.add_row(id=102, x=2, y=group1) table1.add_row(id=103, x=3, y=group2) - table2 = DynamicTable( - name='table2', - description='test table 2' - ) + table2 = DynamicTable(name="table2", description="test table 2") # create a ragged column that 
references table1 # each row of table2 corresponds to one or more rows of table 1 table2.add_column( - name='electrodes', - description='column description', + name="electrodes", + description="column description", index=True, - table=table1 + table=table1, ) table2.add_row(id=10, electrodes=[1, 2]) - multi_container = SimpleMultiContainer(name='multi') + multi_container = SimpleMultiContainer(name="multi") multi_container.add_container(group1) multi_container.add_container(group2) multi_container.add_container(table1) multi_container.add_container(table2) - with HDF5IO(self.filename, manager=get_manager(), mode='w') as io: + with HDF5IO(self.filename, manager=get_manager(), mode="w") as io: io.write(multi_container) - with HDF5IO(self.filename, manager=get_manager(), mode='r') as io: + with HDF5IO(self.filename, manager=get_manager(), mode="r") as io: read_multi_container = io.read() self.assertContainerEqual(read_multi_container, multi_container, ignore_name=True) # test DTR access - read_group1 = read_multi_container['group1'] - read_group2 = read_multi_container['group2'] - read_table = read_multi_container['table2'] - ret = read_table[0, 'electrodes'] - expected = pd.DataFrame({'x': np.array([2, 3]), - 'y': [read_group1, read_group2]}, - index=pd.Index(data=[102, 103], name='id')) + read_group1 = read_multi_container["group1"] + read_group2 = read_multi_container["group2"] + read_table = read_multi_container["table2"] + ret = read_table[0, "electrodes"] + expected = pd.DataFrame( + {"x": np.array([2, 3]), "y": [read_group1, read_group2]}, + index=pd.Index(data=[102, 103], name="id"), + ) pd.testing.assert_frame_equal(ret, expected) class TestVectorIndexDtype(TestCase): - def set_up_array_index(self): - data = VectorData(name='data', description='desc') - index = VectorIndex(name='index', data=np.array([]), target=data) + data = VectorData(name="data", description="desc") + index = VectorIndex(name="index", data=np.array([]), target=data) return index def set_up_list_index(self): - data = VectorData(name='data', description='desc') - index = VectorIndex(name='index', data=[], target=data) + data = VectorData(name="data", description="desc") + index = VectorIndex(name="index", data=[], target=data) return index def test_array_inc_precision(self): index = self.set_up_array_index() - index.add_vector(np.empty((255, ))) + index.add_vector(np.empty((255,))) self.assertEqual(index.data[0], 255) self.assertEqual(index.data.dtype, np.uint8) def test_array_inc_precision_1step(self): index = self.set_up_array_index() - index.add_vector(np.empty((65535, ))) + index.add_vector(np.empty((65535,))) self.assertEqual(index.data[0], 65535) self.assertEqual(index.data.dtype, np.uint16) def test_array_inc_precision_2steps(self): index = self.set_up_array_index() - index.add_vector(np.empty((65536, ))) + index.add_vector(np.empty((65536,))) self.assertEqual(index.data[0], 65536) self.assertEqual(index.data.dtype, np.uint32) def test_array_prev_data_inc_precision_2steps(self): index = self.set_up_array_index() - index.add_vector(np.empty((255, ))) # dtype is still uint8 - index.add_vector(np.empty((65536, ))) + index.add_vector(np.empty((255,))) # dtype is still uint8 + index.add_vector(np.empty((65536,))) self.assertEqual(index.data[0], 255) # make sure the 255 is upgraded self.assertEqual(index.data.dtype, np.uint32) diff --git a/tests/unit/spec_tests/test_attribute_spec.py b/tests/unit/spec_tests/test_attribute_spec.py index 15102e728..c1522aeab 100644 --- a/tests/unit/spec_tests/test_attribute_spec.py 
+++ b/tests/unit/spec_tests/test_attribute_spec.py @@ -5,89 +5,92 @@ class AttributeSpecTests(TestCase): - def test_constructor(self): - spec = AttributeSpec('attribute1', - 'my first attribute', - 'text') - self.assertEqual(spec['name'], 'attribute1') - self.assertEqual(spec['dtype'], 'text') - self.assertEqual(spec['doc'], 'my first attribute') + spec = AttributeSpec("attribute1", "my first attribute", "text") + self.assertEqual(spec["name"], "attribute1") + self.assertEqual(spec["dtype"], "text") + self.assertEqual(spec["doc"], "my first attribute") self.assertIsNone(spec.parent) json.dumps(spec) # to ensure there are no circular links def test_invalid_dtype(self): with self.assertRaises(ValueError): - AttributeSpec(name='attribute1', - doc='my first attribute', - dtype='invalid' # <-- Invalid dtype must raise a ValueError - ) + AttributeSpec( + name="attribute1", + doc="my first attribute", + dtype="invalid", # <-- Invalid dtype must raise a ValueError + ) def test_both_value_and_default_value_set(self): with self.assertRaises(ValueError): - AttributeSpec(name='attribute1', - doc='my first attribute', - dtype='int', - value=5, - default_value=10 # <-- Default_value and value can't be set at the same time - ) + AttributeSpec( + name="attribute1", + doc="my first attribute", + dtype="int", + value=5, + default_value=10, # <-- Default_value and value can't be set at the same time + ) def test_colliding_shape_and_dims(self): with self.assertRaises(ValueError): - AttributeSpec(name='attribute1', - doc='my first attribute', - dtype='int', - dims=['test'], - shape=[None, 2] # <-- Length of shape and dims do not match must raise a ValueError - ) + AttributeSpec( + name="attribute1", + doc="my first attribute", + dtype="int", + dims=["test"], + shape=[ + None, + 2, + ], # <-- Length of shape and dims do not match must raise a ValueError + ) def test_default_value(self): - spec = AttributeSpec('attribute1', - 'my first attribute', - 'text', - default_value='some text') - self.assertEqual(spec['default_value'], 'some text') - self.assertEqual(spec.default_value, 'some text') + spec = AttributeSpec( + "attribute1", + "my first attribute", + "text", + default_value="some text", + ) + self.assertEqual(spec["default_value"], "some text") + self.assertEqual(spec.default_value, "some text") def test_shape(self): shape = [None, 2] - spec = AttributeSpec('attribute1', - 'my first attribute', - 'text', - shape=shape) - self.assertEqual(spec['shape'], shape) + spec = AttributeSpec("attribute1", "my first attribute", "text", shape=shape) + self.assertEqual(spec["shape"], shape) self.assertEqual(spec.shape, shape) def test_dims_without_shape(self): - spec = AttributeSpec('attribute1', - 'my first attribute', - 'text', - dims=['test']) - self.assertEqual(spec.shape, (None, )) + spec = AttributeSpec("attribute1", "my first attribute", "text", dims=["test"]) + self.assertEqual(spec.shape, (None,)) def test_build_spec(self): - spec_dict = {'name': 'attribute1', - 'doc': 'my first attribute', - 'dtype': 'text', - 'shape': [None], - 'dims': ['dim1'], - 'value': ['a', 'b']} + spec_dict = { + "name": "attribute1", + "doc": "my first attribute", + "dtype": "text", + "shape": [None], + "dims": ["dim1"], + "value": ["a", "b"], + } ret = AttributeSpec.build_spec(spec_dict) self.assertTrue(isinstance(ret, AttributeSpec)) self.assertDictEqual(ret, spec_dict) def test_build_spec_reftype(self): - spec_dict = {'name': 'attribute1', - 'doc': 'my first attribute', - 'dtype': {'target_type': 'AnotherType', 'reftype': 'object'}} 
+ spec_dict = { + "name": "attribute1", + "doc": "my first attribute", + "dtype": {"target_type": "AnotherType", "reftype": "object"}, + } expected = spec_dict.copy() - expected['dtype'] = RefSpec(target_type='AnotherType', reftype='object') + expected["dtype"] = RefSpec(target_type="AnotherType", reftype="object") ret = AttributeSpec.build_spec(spec_dict) self.assertTrue(isinstance(ret, AttributeSpec)) self.assertDictEqual(ret, expected) def test_build_spec_no_doc(self): - spec_dict = {'name': 'attribute1', 'dtype': 'text'} + spec_dict = {"name": "attribute1", "dtype": "text"} msg = "AttributeSpec.__init__: missing argument 'doc'" with self.assertRaisesWith(TypeError, msg): AttributeSpec.build_spec(spec_dict) diff --git a/tests/unit/spec_tests/test_dataset_spec.py b/tests/unit/spec_tests/test_dataset_spec.py index 0309aced4..550910e3b 100644 --- a/tests/unit/spec_tests/test_dataset_spec.py +++ b/tests/unit/spec_tests/test_dataset_spec.py @@ -1,247 +1,290 @@ import json -from hdmf.spec import GroupSpec, DatasetSpec, AttributeSpec, DtypeSpec, RefSpec +from hdmf.spec import AttributeSpec, DatasetSpec, DtypeSpec, GroupSpec, RefSpec from hdmf.testing import TestCase class DatasetSpecTests(TestCase): def setUp(self): self.attributes = [ - AttributeSpec('attribute1', 'my first attribute', 'text'), - AttributeSpec('attribute2', 'my second attribute', 'text') + AttributeSpec("attribute1", "my first attribute", "text"), + AttributeSpec("attribute2", "my second attribute", "text"), ] def test_constructor(self): - spec = DatasetSpec('my first dataset', - 'int', - name='dataset1', - attributes=self.attributes) - self.assertEqual(spec['dtype'], 'int') - self.assertEqual(spec['name'], 'dataset1') - self.assertEqual(spec['doc'], 'my first dataset') - self.assertNotIn('linkable', spec) - self.assertNotIn('data_type_def', spec) - self.assertListEqual(spec['attributes'], self.attributes) + spec = DatasetSpec( + "my first dataset", + "int", + name="dataset1", + attributes=self.attributes, + ) + self.assertEqual(spec["dtype"], "int") + self.assertEqual(spec["name"], "dataset1") + self.assertEqual(spec["doc"], "my first dataset") + self.assertNotIn("linkable", spec) + self.assertNotIn("data_type_def", spec) + self.assertListEqual(spec["attributes"], self.attributes) self.assertIs(spec, self.attributes[0].parent) self.assertIs(spec, self.attributes[1].parent) json.dumps(spec) def test_constructor_datatype(self): - spec = DatasetSpec('my first dataset', - 'int', - name='dataset1', - attributes=self.attributes, - linkable=False, - data_type_def='EphysData') - self.assertEqual(spec['dtype'], 'int') - self.assertEqual(spec['name'], 'dataset1') - self.assertEqual(spec['doc'], 'my first dataset') - self.assertEqual(spec['data_type_def'], 'EphysData') - self.assertFalse(spec['linkable']) - self.assertListEqual(spec['attributes'], self.attributes) + spec = DatasetSpec( + "my first dataset", + "int", + name="dataset1", + attributes=self.attributes, + linkable=False, + data_type_def="EphysData", + ) + self.assertEqual(spec["dtype"], "int") + self.assertEqual(spec["name"], "dataset1") + self.assertEqual(spec["doc"], "my first dataset") + self.assertEqual(spec["data_type_def"], "EphysData") + self.assertFalse(spec["linkable"]) + self.assertListEqual(spec["attributes"], self.attributes) self.assertIs(spec, self.attributes[0].parent) self.assertIs(spec, self.attributes[1].parent) def test_constructor_shape(self): shape = [None, 2] - spec = DatasetSpec('my first dataset', - 'int', - name='dataset1', - shape=shape, - 
attributes=self.attributes) - self.assertEqual(spec['shape'], shape) + spec = DatasetSpec( + "my first dataset", + "int", + name="dataset1", + shape=shape, + attributes=self.attributes, + ) + self.assertEqual(spec["shape"], shape) self.assertEqual(spec.shape, shape) def test_constructor_invalidate_dtype(self): with self.assertRaises(ValueError): - DatasetSpec(doc='my first dataset', - dtype='my bad dtype', # <-- Expect AssertionError due to bad type - name='dataset1', - dims=(None, None), - attributes=self.attributes, - linkable=False, - data_type_def='EphysData') + DatasetSpec( + doc="my first dataset", + dtype="my bad dtype", # <-- Expect AssertionError due to bad type + name="dataset1", + dims=(None, None), + attributes=self.attributes, + linkable=False, + data_type_def="EphysData", + ) def test_constructor_ref_spec(self): - dtype = RefSpec('TimeSeries', 'object') - spec = DatasetSpec(doc='my first dataset', - dtype=dtype, - name='dataset1', - dims=(None, None), - attributes=self.attributes, - linkable=False, - data_type_def='EphysData') - self.assertDictEqual(spec['dtype'], dtype) + dtype = RefSpec("TimeSeries", "object") + spec = DatasetSpec( + doc="my first dataset", + dtype=dtype, + name="dataset1", + dims=(None, None), + attributes=self.attributes, + linkable=False, + data_type_def="EphysData", + ) + self.assertDictEqual(spec["dtype"], dtype) def test_datatype_extension(self): - base = DatasetSpec('my first dataset', - 'int', - name='dataset1', - attributes=self.attributes, - linkable=False, - data_type_def='EphysData') - - attributes = [AttributeSpec('attribute3', 'my first extending attribute', 'float')] - ext = DatasetSpec('my first dataset extension', - 'int', - name='dataset1', - attributes=attributes, - linkable=False, - data_type_inc=base, - data_type_def='SpikeData') - self.assertDictEqual(ext['attributes'][0], attributes[0]) - self.assertDictEqual(ext['attributes'][1], self.attributes[0]) - self.assertDictEqual(ext['attributes'][2], self.attributes[1]) + base = DatasetSpec( + "my first dataset", + "int", + name="dataset1", + attributes=self.attributes, + linkable=False, + data_type_def="EphysData", + ) + + attributes = [AttributeSpec("attribute3", "my first extending attribute", "float")] + ext = DatasetSpec( + "my first dataset extension", + "int", + name="dataset1", + attributes=attributes, + linkable=False, + data_type_inc=base, + data_type_def="SpikeData", + ) + self.assertDictEqual(ext["attributes"][0], attributes[0]) + self.assertDictEqual(ext["attributes"][1], self.attributes[0]) + self.assertDictEqual(ext["attributes"][2], self.attributes[1]) ext_attrs = ext.attributes self.assertIs(ext, ext_attrs[0].parent) self.assertIs(ext, ext_attrs[1].parent) self.assertIs(ext, ext_attrs[2].parent) def test_datatype_extension_groupspec(self): - '''Test to make sure DatasetSpec catches when a GroupSpec used as data_type_inc''' - base = GroupSpec('a fake grop', - data_type_def='EphysData') + """Test to make sure DatasetSpec catches when a GroupSpec used as data_type_inc""" + base = GroupSpec("a fake group", data_type_def="EphysData") with self.assertRaises(TypeError): - DatasetSpec('my first dataset extension', - 'int', - name='dataset1', - data_type_inc=base, - data_type_def='SpikeData') + DatasetSpec( + "my first dataset extension", + "int", + name="dataset1", + data_type_inc=base, + data_type_def="SpikeData", + ) def test_constructor_table(self): - dtype1 = DtypeSpec('column1', 'the first column', 'int') - dtype2 = DtypeSpec('column2', 'the second column', 'float') - spec = 
DatasetSpec('my first table', - [dtype1, dtype2], - name='table1', - attributes=self.attributes) - self.assertEqual(spec['dtype'], [dtype1, dtype2]) - self.assertEqual(spec['name'], 'table1') - self.assertEqual(spec['doc'], 'my first table') - self.assertNotIn('linkable', spec) - self.assertNotIn('data_type_def', spec) - self.assertListEqual(spec['attributes'], self.attributes) + dtype1 = DtypeSpec("column1", "the first column", "int") + dtype2 = DtypeSpec("column2", "the second column", "float") + spec = DatasetSpec( + "my first table", + [dtype1, dtype2], + name="table1", + attributes=self.attributes, + ) + self.assertEqual(spec["dtype"], [dtype1, dtype2]) + self.assertEqual(spec["name"], "table1") + self.assertEqual(spec["doc"], "my first table") + self.assertNotIn("linkable", spec) + self.assertNotIn("data_type_def", spec) + self.assertListEqual(spec["attributes"], self.attributes) self.assertIs(spec, self.attributes[0].parent) self.assertIs(spec, self.attributes[1].parent) json.dumps(spec) def test_constructor_invalid_table(self): with self.assertRaises(ValueError): - DatasetSpec('my first table', - [DtypeSpec('column1', 'the first column', 'int'), - {} # <--- Bad compound type spec must raise an error - ], - name='table1', - attributes=self.attributes) + DatasetSpec( + "my first table", + [ + DtypeSpec("column1", "the first column", "int"), + {}, # <--- Bad compound type spec must raise an error + ], + name="table1", + attributes=self.attributes, + ) def test_constructor_default_value(self): - spec = DatasetSpec(doc='test', - default_value=5, - dtype='int', - data_type_def='test') + spec = DatasetSpec(doc="test", default_value=5, dtype="int", data_type_def="test") self.assertEqual(spec.default_value, 5) def test_name_with_incompatible_quantity(self): # Check that we raise an error when the quantity allows more than one instance with a fixed name with self.assertRaises(ValueError): - DatasetSpec(doc='my first dataset', - dtype='int', - name='ds1', - quantity='zero_or_many') + DatasetSpec( + doc="my first dataset", + dtype="int", + name="ds1", + quantity="zero_or_many", + ) with self.assertRaises(ValueError): - DatasetSpec(doc='my first dataset', - dtype='int', - name='ds1', - quantity='one_or_many') + DatasetSpec( + doc="my first dataset", + dtype="int", + name="ds1", + quantity="one_or_many", + ) def test_name_with_compatible_quantity(self): # Make sure compatible quantity flags pass when name is fixed - DatasetSpec(doc='my first dataset', - dtype='int', - name='ds1', - quantity='zero_or_one') - DatasetSpec(doc='my first dataset', - dtype='int', - name='ds1', - quantity=1) + DatasetSpec( + doc="my first dataset", + dtype="int", + name="ds1", + quantity="zero_or_one", + ) + DatasetSpec(doc="my first dataset", dtype="int", name="ds1", quantity=1) def test_datatype_table_extension(self): - dtype1 = DtypeSpec('column1', 'the first column', 'int') - dtype2 = DtypeSpec('column2', 'the second column', 'float') - base = DatasetSpec('my first table', - [dtype1, dtype2], - attributes=self.attributes, - data_type_def='SimpleTable') - self.assertEqual(base['dtype'], [dtype1, dtype2]) - self.assertEqual(base['doc'], 'my first table') - dtype3 = DtypeSpec('column3', 'the third column', 'text') - ext = DatasetSpec('my first table extension', - [dtype3], - data_type_inc=base, - data_type_def='ExtendedTable') - self.assertEqual(ext['dtype'], [dtype1, dtype2, dtype3]) - self.assertEqual(ext['doc'], 'my first table extension') + dtype1 = DtypeSpec("column1", "the first column", "int") + dtype2 = 
DtypeSpec("column2", "the second column", "float") + base = DatasetSpec( + "my first table", + [dtype1, dtype2], + attributes=self.attributes, + data_type_def="SimpleTable", + ) + self.assertEqual(base["dtype"], [dtype1, dtype2]) + self.assertEqual(base["doc"], "my first table") + dtype3 = DtypeSpec("column3", "the third column", "text") + ext = DatasetSpec( + "my first table extension", + [dtype3], + data_type_inc=base, + data_type_def="ExtendedTable", + ) + self.assertEqual(ext["dtype"], [dtype1, dtype2, dtype3]) + self.assertEqual(ext["doc"], "my first table extension") def test_datatype_table_extension_higher_precision(self): - dtype1 = DtypeSpec('column1', 'the first column', 'int') - dtype2 = DtypeSpec('column2', 'the second column', 'float32') - base = DatasetSpec('my first table', - [dtype1, dtype2], - attributes=self.attributes, - data_type_def='SimpleTable') - self.assertEqual(base['dtype'], [dtype1, dtype2]) - self.assertEqual(base['doc'], 'my first table') - dtype3 = DtypeSpec('column2', 'the second column, with greater precision', 'float64') - ext = DatasetSpec('my first table extension', - [dtype3], - data_type_inc=base, - data_type_def='ExtendedTable') - self.assertEqual(ext['dtype'], [dtype1, dtype3]) - self.assertEqual(ext['doc'], 'my first table extension') + dtype1 = DtypeSpec("column1", "the first column", "int") + dtype2 = DtypeSpec("column2", "the second column", "float32") + base = DatasetSpec( + "my first table", + [dtype1, dtype2], + attributes=self.attributes, + data_type_def="SimpleTable", + ) + self.assertEqual(base["dtype"], [dtype1, dtype2]) + self.assertEqual(base["doc"], "my first table") + dtype3 = DtypeSpec("column2", "the second column, with greater precision", "float64") + ext = DatasetSpec( + "my first table extension", + [dtype3], + data_type_inc=base, + data_type_def="ExtendedTable", + ) + self.assertEqual(ext["dtype"], [dtype1, dtype3]) + self.assertEqual(ext["doc"], "my first table extension") def test_datatype_table_extension_lower_precision(self): - dtype1 = DtypeSpec('column1', 'the first column', 'int') - dtype2 = DtypeSpec('column2', 'the second column', 'float64') - base = DatasetSpec('my first table', - [dtype1, dtype2], - attributes=self.attributes, - data_type_def='SimpleTable') - self.assertEqual(base['dtype'], [dtype1, dtype2]) - self.assertEqual(base['doc'], 'my first table') - dtype3 = DtypeSpec('column2', 'the second column, with greater precision', 'float32') - with self.assertRaisesWith(ValueError, 'Cannot extend float64 to float32'): - DatasetSpec('my first table extension', - [dtype3], - data_type_inc=base, - data_type_def='ExtendedTable') + dtype1 = DtypeSpec("column1", "the first column", "int") + dtype2 = DtypeSpec("column2", "the second column", "float64") + base = DatasetSpec( + "my first table", + [dtype1, dtype2], + attributes=self.attributes, + data_type_def="SimpleTable", + ) + self.assertEqual(base["dtype"], [dtype1, dtype2]) + self.assertEqual(base["doc"], "my first table") + dtype3 = DtypeSpec("column2", "the second column, with greater precision", "float32") + with self.assertRaisesWith(ValueError, "Cannot extend float64 to float32"): + DatasetSpec( + "my first table extension", + [dtype3], + data_type_inc=base, + data_type_def="ExtendedTable", + ) def test_datatype_table_extension_diff_format(self): - dtype1 = DtypeSpec('column1', 'the first column', 'int') - dtype2 = DtypeSpec('column2', 'the second column', 'float64') - base = DatasetSpec('my first table', - [dtype1, dtype2], - attributes=self.attributes, - 
data_type_def='SimpleTable') - self.assertEqual(base['dtype'], [dtype1, dtype2]) - self.assertEqual(base['doc'], 'my first table') - dtype3 = DtypeSpec('column2', 'the second column, with greater precision', 'int32') - with self.assertRaisesWith(ValueError, 'Cannot extend float64 to int32'): - DatasetSpec('my first table extension', - [dtype3], - data_type_inc=base, - data_type_def='ExtendedTable') + dtype1 = DtypeSpec("column1", "the first column", "int") + dtype2 = DtypeSpec("column2", "the second column", "float64") + base = DatasetSpec( + "my first table", + [dtype1, dtype2], + attributes=self.attributes, + data_type_def="SimpleTable", + ) + self.assertEqual(base["dtype"], [dtype1, dtype2]) + self.assertEqual(base["doc"], "my first table") + dtype3 = DtypeSpec("column2", "the second column, with greater precision", "int32") + with self.assertRaisesWith(ValueError, "Cannot extend float64 to int32"): + DatasetSpec( + "my first table extension", + [dtype3], + data_type_inc=base, + data_type_def="ExtendedTable", + ) def test_data_type_property_value(self): """Test that the property data_type has the expected value""" test_cases = { - ('Foo', 'Bar'): 'Bar', - ('Foo', None): 'Foo', - (None, 'Bar'): 'Bar', + ("Foo", "Bar"): "Bar", + ("Foo", None): "Foo", + (None, "Bar"): "Bar", (None, None): None, } for (data_type_inc, data_type_def), data_type in test_cases.items(): - with self.subTest(data_type_inc=data_type_inc, - data_type_def=data_type_def, data_type=data_type): - group = GroupSpec('A group', name='group', - data_type_inc=data_type_inc, data_type_def=data_type_def) + with self.subTest( + data_type_inc=data_type_inc, + data_type_def=data_type_def, + data_type=data_type, + ): + group = GroupSpec( + "A group", + name="group", + data_type_inc=data_type_inc, + data_type_def=data_type_def, + ) self.assertEqual(group.data_type, data_type) diff --git a/tests/unit/spec_tests/test_dtype_spec.py b/tests/unit/spec_tests/test_dtype_spec.py index 946bbb9b7..97dabaf3b 100644 --- a/tests/unit/spec_tests/test_dtype_spec.py +++ b/tests/unit/spec_tests/test_dtype_spec.py @@ -1,4 +1,4 @@ -from hdmf.spec import DtypeSpec, DtypeHelper, RefSpec +from hdmf.spec import DtypeHelper, DtypeSpec, RefSpec from hdmf.testing import TestCase @@ -7,38 +7,50 @@ def setUp(self): pass def test_recommended_dtypes(self): - self.assertListEqual(DtypeHelper.recommended_primary_dtypes, - list(DtypeHelper.primary_dtype_synonyms.keys())) + self.assertListEqual( + DtypeHelper.recommended_primary_dtypes, + list(DtypeHelper.primary_dtype_synonyms.keys()), + ) def test_valid_primary_dtypes(self): - a = set(list(DtypeHelper.primary_dtype_synonyms.keys()) + - [vi for v in DtypeHelper.primary_dtype_synonyms.values() for vi in v]) + a = set( + list(DtypeHelper.primary_dtype_synonyms.keys()) + + [vi for v in DtypeHelper.primary_dtype_synonyms.values() for vi in v] + ) self.assertSetEqual(a, DtypeHelper.valid_primary_dtypes) def test_simplify_cpd_type(self): - compound_type = [DtypeSpec('test', 'test field', 'float'), - DtypeSpec('test2', 'test field2', 'int')] - expected_result = ['float', 'int'] + compound_type = [ + DtypeSpec("test", "test field", "float"), + DtypeSpec("test2", "test field2", "int"), + ] + expected_result = ["float", "int"] result = DtypeHelper.simplify_cpd_type(compound_type) self.assertListEqual(result, expected_result) def test_simplify_cpd_type_ref(self): - compound_type = [DtypeSpec('test', 'test field', 'float'), - DtypeSpec('test2', 'test field2', RefSpec(target_type='MyType', reftype='object'))] - expected_result 
= ['float', 'object'] + compound_type = [ + DtypeSpec("test", "test field", "float"), + DtypeSpec( + "test2", + "test field2", + RefSpec(target_type="MyType", reftype="object"), + ), + ] + expected_result = ["float", "object"] result = DtypeHelper.simplify_cpd_type(compound_type) self.assertListEqual(result, expected_result) def test_check_dtype_ok(self): - self.assertEqual('int', DtypeHelper.check_dtype('int')) + self.assertEqual("int", DtypeHelper.check_dtype("int")) def test_check_dtype_bad(self): msg = "dtype 'bad dtype' is not a valid primary data type." with self.assertRaisesRegex(ValueError, msg): - DtypeHelper.check_dtype('bad dtype') + DtypeHelper.check_dtype("bad dtype") def test_check_dtype_ref(self): - refspec = RefSpec(target_type='target', reftype='object') + refspec = RefSpec(target_type="target", reftype="object") self.assertIs(refspec, DtypeHelper.check_dtype(refspec)) @@ -47,34 +59,34 @@ def setUp(self): pass def test_constructor(self): - spec = DtypeSpec('column1', 'an example column', 'int') - self.assertEqual(spec.doc, 'an example column') - self.assertEqual(spec.name, 'column1') - self.assertEqual(spec.dtype, 'int') + spec = DtypeSpec("column1", "an example column", "int") + self.assertEqual(spec.doc, "an example column") + self.assertEqual(spec.name, "column1") + self.assertEqual(spec.dtype, "int") def test_build_spec(self): - spec = DtypeSpec.build_spec({'doc': 'an example column', 'name': 'column1', 'dtype': 'int'}) - self.assertEqual(spec.doc, 'an example column') - self.assertEqual(spec.name, 'column1') - self.assertEqual(spec.dtype, 'int') + spec = DtypeSpec.build_spec({"doc": "an example column", "name": "column1", "dtype": "int"}) + self.assertEqual(spec.doc, "an example column") + self.assertEqual(spec.name, "column1") + self.assertEqual(spec.dtype, "int") def test_invalid_refspec_dict(self): """Test missing or bad target key for RefSpec.""" msg = "'dtype' must have the key 'target_type'" with self.assertRaisesWith(ValueError, msg): - DtypeSpec.assertValidDtype({'no target': 'test', 'reftype': 'object'}) + DtypeSpec.assertValidDtype({"no target": "test", "reftype": "object"}) def test_refspec_dtype(self): # just making sure this does not cause an error - DtypeSpec('column1', 'an example column', RefSpec('TimeSeries', 'object')) + DtypeSpec("column1", "an example column", RefSpec("TimeSeries", "object")) def test_invalid_dtype(self): msg = "dtype 'bad dtype' is not a valid primary data type." 
with self.assertRaisesRegex(ValueError, msg): - DtypeSpec('column1', 'an example column', dtype='bad dtype') + DtypeSpec("column1", "an example column", dtype="bad dtype") def test_is_ref(self): - spec = DtypeSpec('column1', 'an example column', RefSpec('TimeSeries', 'object')) + spec = DtypeSpec("column1", "an example column", RefSpec("TimeSeries", "object")) self.assertTrue(DtypeSpec.is_ref(spec)) - spec = DtypeSpec('column1', 'an example column', 'int') + spec = DtypeSpec("column1", "an example column", "int") self.assertFalse(DtypeSpec.is_ref(spec)) diff --git a/tests/unit/spec_tests/test_group_spec.py b/tests/unit/spec_tests/test_group_spec.py index 9c117fa1f..e03eddfcd 100644 --- a/tests/unit/spec_tests/test_group_spec.py +++ b/tests/unit/spec_tests/test_group_spec.py @@ -1,60 +1,62 @@ import json -from hdmf.spec import GroupSpec, DatasetSpec, AttributeSpec, LinkSpec +from hdmf.spec import AttributeSpec, DatasetSpec, GroupSpec, LinkSpec from hdmf.testing import TestCase class GroupSpecTests(TestCase): def setUp(self): self.attributes = [ - AttributeSpec('attribute1', 'my first attribute', 'text'), - AttributeSpec('attribute2', 'my second attribute', 'text') + AttributeSpec("attribute1", "my first attribute", "text"), + AttributeSpec("attribute2", "my second attribute", "text"), ] self.dset1_attributes = [ - AttributeSpec('attribute3', 'my third attribute', 'text'), - AttributeSpec('attribute4', 'my fourth attribute', 'text') + AttributeSpec("attribute3", "my third attribute", "text"), + AttributeSpec("attribute4", "my fourth attribute", "text"), ] self.dset2_attributes = [ - AttributeSpec('attribute5', 'my fifth attribute', 'text'), - AttributeSpec('attribute6', 'my sixth attribute', 'text') + AttributeSpec("attribute5", "my fifth attribute", "text"), + AttributeSpec("attribute6", "my sixth attribute", "text"), ] self.datasets = [ - DatasetSpec('my first dataset', - 'int', - name='dataset1', - attributes=self.dset1_attributes, - linkable=True), - DatasetSpec('my second dataset', - 'int', - name='dataset2', - attributes=self.dset2_attributes, - linkable=True, - data_type_def='VoltageArray') + DatasetSpec( + "my first dataset", + "int", + name="dataset1", + attributes=self.dset1_attributes, + linkable=True, + ), + DatasetSpec( + "my second dataset", + "int", + name="dataset2", + attributes=self.dset2_attributes, + linkable=True, + data_type_def="VoltageArray", + ), ] self.subgroups = [ - GroupSpec('A test subgroup', - name='subgroup1', - linkable=False), - GroupSpec('A test subgroup', - name='subgroup2', - linkable=False) + GroupSpec("A test subgroup", name="subgroup1", linkable=False), + GroupSpec("A test subgroup", name="subgroup2", linkable=False), ] def test_constructor(self): - spec = GroupSpec('A test group', - name='root_constructor', - groups=self.subgroups, - datasets=self.datasets, - attributes=self.attributes, - linkable=False) - self.assertFalse(spec['linkable']) - self.assertListEqual(spec['attributes'], self.attributes) - self.assertListEqual(spec['datasets'], self.datasets) - self.assertNotIn('data_type_def', spec) + spec = GroupSpec( + "A test group", + name="root_constructor", + groups=self.subgroups, + datasets=self.datasets, + attributes=self.attributes, + linkable=False, + ) + self.assertFalse(spec["linkable"]) + self.assertListEqual(spec["attributes"], self.attributes) + self.assertListEqual(spec["datasets"], self.datasets) + self.assertNotIn("data_type_def", spec) self.assertIs(spec, self.subgroups[0].parent) self.assertIs(spec, self.subgroups[1].parent) 
self.assertIs(spec, self.attributes[0].parent) @@ -64,125 +66,123 @@ def test_constructor(self): json.dumps(spec) def test_constructor_datatype(self): - spec = GroupSpec('A test group', - name='root_constructor_datatype', - datasets=self.datasets, - attributes=self.attributes, - linkable=False, - data_type_def='EphysData') - self.assertFalse(spec['linkable']) - self.assertListEqual(spec['attributes'], self.attributes) - self.assertListEqual(spec['datasets'], self.datasets) - self.assertEqual(spec['data_type_def'], 'EphysData') + spec = GroupSpec( + "A test group", + name="root_constructor_datatype", + datasets=self.datasets, + attributes=self.attributes, + linkable=False, + data_type_def="EphysData", + ) + self.assertFalse(spec["linkable"]) + self.assertListEqual(spec["attributes"], self.attributes) + self.assertListEqual(spec["datasets"], self.datasets) + self.assertEqual(spec["data_type_def"], "EphysData") self.assertIs(spec, self.attributes[0].parent) self.assertIs(spec, self.attributes[1].parent) self.assertIs(spec, self.datasets[0].parent) self.assertIs(spec, self.datasets[1].parent) - self.assertEqual(spec.data_type_def, 'EphysData') + self.assertEqual(spec.data_type_def, "EphysData") self.assertIsNone(spec.data_type_inc) json.dumps(spec) def test_set_parent_exists(self): - GroupSpec('A test group', - name='root_constructor', - groups=self.subgroups) - msg = 'Cannot re-assign parent.' + GroupSpec("A test group", name="root_constructor", groups=self.subgroups) + msg = "Cannot re-assign parent." with self.assertRaisesWith(AttributeError, msg): self.subgroups[0].parent = self.subgroups[1] def test_set_dataset(self): - spec = GroupSpec('A test group', - name='root_test_set_dataset', - linkable=False, - data_type_def='EphysData') + spec = GroupSpec( + "A test group", + name="root_test_set_dataset", + linkable=False, + data_type_def="EphysData", + ) spec.set_dataset(self.datasets[0]) self.assertIs(spec, self.datasets[0].parent) def test_set_link(self): - group = GroupSpec( - doc='A test group', - name='root' - ) - link = LinkSpec( - doc='A test link', - target_type='LinkTarget', - name='link_name' - ) + group = GroupSpec(doc="A test group", name="root") + link = LinkSpec(doc="A test link", target_type="LinkTarget", name="link_name") group.set_link(link) self.assertIs(group, link.parent) - self.assertIs(group.get_link('link_name'), link) + self.assertIs(group.get_link("link_name"), link) def test_add_link(self): - group = GroupSpec( - doc='A test group', - name='root' - ) - group.add_link( - 'A test link', - 'LinkTarget', - name='link_name' - ) - self.assertIsInstance(group.get_link('link_name'), LinkSpec) + group = GroupSpec(doc="A test group", name="root") + group.add_link("A test link", "LinkTarget", name="link_name") + self.assertIsInstance(group.get_link("link_name"), LinkSpec) def test_set_group(self): - spec = GroupSpec('A test group', - name='root_test_set_group', - linkable=False, - data_type_def='EphysData') + spec = GroupSpec( + "A test group", + name="root_test_set_group", + linkable=False, + data_type_def="EphysData", + ) spec.set_group(self.subgroups[0]) spec.set_group(self.subgroups[1]) - self.assertListEqual(spec['groups'], self.subgroups) + self.assertListEqual(spec["groups"], self.subgroups) self.assertIs(spec, self.subgroups[0].parent) self.assertIs(spec, self.subgroups[1].parent) json.dumps(spec) def test_add_group(self): - group = GroupSpec( - doc='A test group', - name='root' - ) - group.add_group( - 'A test group', - name='subgroup' - ) - 
self.assertIsInstance(group.get_group('subgroup'), GroupSpec) + group = GroupSpec(doc="A test group", name="root") + group.add_group("A test group", name="subgroup") + self.assertIsInstance(group.get_group("subgroup"), GroupSpec) def test_type_extension(self): - spec = GroupSpec('A test group', - name='parent_type', - datasets=self.datasets, - attributes=self.attributes, - linkable=False, - data_type_def='EphysData') + spec = GroupSpec( + "A test group", + name="parent_type", + datasets=self.datasets, + attributes=self.attributes, + linkable=False, + data_type_def="EphysData", + ) dset1_attributes_ext = [ - AttributeSpec('dset1_extra_attribute', 'an extra attribute for the first dataset', 'text') + AttributeSpec( + "dset1_extra_attribute", + "an extra attribute for the first dataset", + "text", + ) ] ext_datasets = [ - DatasetSpec('my first dataset extension', - 'int', - name='dataset1', - attributes=dset1_attributes_ext, - linkable=True), + DatasetSpec( + "my first dataset extension", + "int", + name="dataset1", + attributes=dset1_attributes_ext, + linkable=True, + ), ] ext_attributes = [ - AttributeSpec('ext_extra_attribute', 'an extra attribute for the group', 'text'), + AttributeSpec( + "ext_extra_attribute", + "an extra attribute for the group", + "text", + ), ] - ext = GroupSpec('A test group extension', - name='child_type', - datasets=ext_datasets, - attributes=ext_attributes, - linkable=False, - data_type_inc=spec, - data_type_def='SpikeData') - ext_dset1 = ext.get_dataset('dataset1') + ext = GroupSpec( + "A test group extension", + name="child_type", + datasets=ext_datasets, + attributes=ext_attributes, + linkable=False, + data_type_inc=spec, + data_type_def="SpikeData", + ) + ext_dset1 = ext.get_dataset("dataset1") ext_dset1_attrs = ext_dset1.attributes self.assertDictEqual(ext_dset1_attrs[0], dset1_attributes_ext[0]) self.assertDictEqual(ext_dset1_attrs[1], self.dset1_attributes[0]) self.assertDictEqual(ext_dset1_attrs[2], self.dset1_attributes[1]) - self.assertEqual(ext.data_type_def, 'SpikeData') - self.assertEqual(ext.data_type_inc, 'EphysData') + self.assertEqual(ext.data_type_def, "SpikeData") + self.assertEqual(ext.data_type_inc, "EphysData") - ext_dset2 = ext.get_dataset('dataset2') + ext_dset2 = ext.get_dataset("dataset2") self.maxDiff = None # this will suffice for now, assertDictEqual doesn't do deep equality checks self.assertEqual(str(ext_dset2), str(self.datasets[1])) @@ -205,7 +205,7 @@ def assertDatasetsEqual(self, spec1, spec2): spec1_dsets = spec1.datasets spec2_dsets = spec2.datasets if len(spec1_dsets) != len(spec2_dsets): - raise AssertionError('different number of AttributeSpecs') + raise AssertionError("different number of AttributeSpecs") else: for i in range(len(spec1_dsets)): self.assertAttributesEqual(spec1_dsets[i], spec2_dsets[i]) @@ -214,22 +214,24 @@ def assertAttributesEqual(self, spec1, spec2): spec1_attr = spec1.attributes spec2_attr = spec2.attributes if len(spec1_attr) != len(spec2_attr): - raise AssertionError('different number of AttributeSpecs') + raise AssertionError("different number of AttributeSpecs") else: for i in range(len(spec1_attr)): self.assertDictEqual(spec1_attr[i], spec2_attr[i]) def test_add_attribute(self): - spec = GroupSpec('A test group', - name='root_constructor', - groups=self.subgroups, - datasets=self.datasets, - linkable=False) + spec = GroupSpec( + "A test group", + name="root_constructor", + groups=self.subgroups, + datasets=self.datasets, + linkable=False, + ) for attrspec in self.attributes: 
spec.add_attribute(**attrspec) - self.assertListEqual(spec['attributes'], self.attributes) - self.assertListEqual(spec['datasets'], self.datasets) - self.assertNotIn('data_type_def', spec) + self.assertListEqual(spec["attributes"], self.attributes) + self.assertListEqual(spec["datasets"], self.datasets) + self.assertNotIn("data_type_def", spec) self.assertIs(spec, self.subgroups[0].parent) self.assertIs(spec, self.subgroups[1].parent) self.assertIs(spec, spec.attributes[0].parent) @@ -239,203 +241,221 @@ def test_add_attribute(self): json.dumps(spec) def test_update_attribute_spec(self): - spec = GroupSpec('A test group', - name='root_constructor', - attributes=[AttributeSpec('attribute1', 'my first attribute', 'text'), - AttributeSpec('attribute2', 'my second attribute', 'text')]) - spec.set_attribute(AttributeSpec('attribute2', 'my second attribute', 'int', value=5)) - res = spec.get_attribute('attribute2') + spec = GroupSpec( + "A test group", + name="root_constructor", + attributes=[ + AttributeSpec("attribute1", "my first attribute", "text"), + AttributeSpec("attribute2", "my second attribute", "text"), + ], + ) + spec.set_attribute(AttributeSpec("attribute2", "my second attribute", "int", value=5)) + res = spec.get_attribute("attribute2") self.assertEqual(res.value, 5) - self.assertEqual(res.dtype, 'int') + self.assertEqual(res.dtype, "int") def test_path(self): - GroupSpec('A test group', - name='root_constructor', - groups=self.subgroups, - datasets=self.datasets, - attributes=self.attributes, - linkable=False) - self.assertEqual(self.attributes[0].path, 'root_constructor/attribute1') - self.assertEqual(self.datasets[0].path, 'root_constructor/dataset1') - self.assertEqual(self.subgroups[0].path, 'root_constructor/subgroup1') + GroupSpec( + "A test group", + name="root_constructor", + groups=self.subgroups, + datasets=self.datasets, + attributes=self.attributes, + linkable=False, + ) + self.assertEqual(self.attributes[0].path, "root_constructor/attribute1") + self.assertEqual(self.datasets[0].path, "root_constructor/dataset1") + self.assertEqual(self.subgroups[0].path, "root_constructor/subgroup1") def test_path_complicated(self): - attribute = AttributeSpec('attribute1', 'my fifth attribute', 'text') - dataset = DatasetSpec('my first dataset', - 'int', - name='dataset1', - attributes=[attribute]) - subgroup = GroupSpec('A subgroup', - name='subgroup1', - datasets=[dataset]) - self.assertEqual(attribute.path, 'subgroup1/dataset1/attribute1') - - _ = GroupSpec('A test group', - name='root', - groups=[subgroup]) - - self.assertEqual(attribute.path, 'root/subgroup1/dataset1/attribute1') + attribute = AttributeSpec("attribute1", "my fifth attribute", "text") + dataset = DatasetSpec("my first dataset", "int", name="dataset1", attributes=[attribute]) + subgroup = GroupSpec("A subgroup", name="subgroup1", datasets=[dataset]) + self.assertEqual(attribute.path, "subgroup1/dataset1/attribute1") + + _ = GroupSpec("A test group", name="root", groups=[subgroup]) + + self.assertEqual(attribute.path, "root/subgroup1/dataset1/attribute1") def test_path_no_name(self): - attribute = AttributeSpec('attribute1', 'my fifth attribute', 'text') - dataset = DatasetSpec('my first dataset', - 'int', - data_type_inc='DatasetType', - attributes=[attribute]) - subgroup = GroupSpec('A subgroup', - data_type_def='GroupType', - datasets=[dataset]) - _ = GroupSpec('A test group', - name='root', - groups=[subgroup]) - - self.assertEqual(attribute.path, 'root/GroupType/DatasetType/attribute1') + attribute = 
AttributeSpec("attribute1", "my fifth attribute", "text") + dataset = DatasetSpec( + "my first dataset", + "int", + data_type_inc="DatasetType", + attributes=[attribute], + ) + subgroup = GroupSpec("A subgroup", data_type_def="GroupType", datasets=[dataset]) + _ = GroupSpec("A test group", name="root", groups=[subgroup]) + + self.assertEqual(attribute.path, "root/GroupType/DatasetType/attribute1") def test_data_type_property_value(self): """Test that the property data_type has the expected value""" test_cases = { - ('Foo', 'Bar'): 'Bar', - ('Foo', None): 'Foo', - (None, 'Bar'): 'Bar', + ("Foo", "Bar"): "Bar", + ("Foo", None): "Foo", + (None, "Bar"): "Bar", (None, None): None, } for (data_type_inc, data_type_def), data_type in test_cases.items(): - with self.subTest(data_type_inc=data_type_inc, - data_type_def=data_type_def, data_type=data_type): - dataset = DatasetSpec('A dataset', 'int', name='dataset', - data_type_inc=data_type_inc, data_type_def=data_type_def) + with self.subTest( + data_type_inc=data_type_inc, + data_type_def=data_type_def, + data_type=data_type, + ): + dataset = DatasetSpec( + "A dataset", + "int", + name="dataset", + data_type_inc=data_type_inc, + data_type_def=data_type_def, + ) self.assertEqual(dataset.data_type, data_type) def test_get_data_type_spec(self): - expected = AttributeSpec('data_type', 'the data type of this object', 'text', value='MyType') - self.assertDictEqual(GroupSpec.get_data_type_spec('MyType'), expected) + expected = AttributeSpec("data_type", "the data type of this object", "text", value="MyType") + self.assertDictEqual(GroupSpec.get_data_type_spec("MyType"), expected) def test_get_namespace_spec(self): - expected = AttributeSpec('namespace', 'the namespace for the data type of this object', 'text', required=False) + expected = AttributeSpec( + "namespace", + "the namespace for the data type of this object", + "text", + required=False, + ) self.assertDictEqual(GroupSpec.get_namespace_spec(), expected) class TestNotAllowedConfig(TestCase): - def test_no_name_no_def_no_inc(self): - msg = ("Cannot create Group or Dataset spec with no name without specifying 'data_type_def' " - "and/or 'data_type_inc'.") + msg = ( + "Cannot create Group or Dataset spec with no name without specifying" + " 'data_type_def' and/or 'data_type_inc'." 
+ ) with self.assertRaisesWith(ValueError, msg): - GroupSpec('A test group') + GroupSpec("A test group") def test_name_with_multiple(self): - msg = ("Cannot give specific name to something that can exist multiple times: name='MyGroup', quantity='*'") + msg = "Cannot give specific name to something that can exist multiple times: name='MyGroup', quantity='*'" with self.assertRaisesWith(ValueError, msg): - GroupSpec('A test group', name='MyGroup', quantity='*') + GroupSpec("A test group", name="MyGroup", quantity="*") class TestResolveAttrs(TestCase): - def setUp(self): self.def_group_spec = GroupSpec( - doc='A test group', - name='root', - data_type_def='MyGroup', - attributes=[AttributeSpec('attribute1', 'my first attribute', 'text'), - AttributeSpec('attribute2', 'my second attribute', 'text')] + doc="A test group", + name="root", + data_type_def="MyGroup", + attributes=[ + AttributeSpec("attribute1", "my first attribute", "text"), + AttributeSpec("attribute2", "my second attribute", "text"), + ], ) self.inc_group_spec = GroupSpec( - doc='A test group', - name='root', - data_type_inc='MyGroup', - attributes=[AttributeSpec('attribute2', 'my second attribute', 'text', value='fixed'), - AttributeSpec('attribute3', 'my third attribute', 'text', value='fixed')] + doc="A test group", + name="root", + data_type_inc="MyGroup", + attributes=[ + AttributeSpec("attribute2", "my second attribute", "text", value="fixed"), + AttributeSpec("attribute3", "my third attribute", "text", value="fixed"), + ], ) self.inc_group_spec.resolve_spec(self.def_group_spec) def test_resolved(self): - self.assertTupleEqual(self.inc_group_spec.attributes, ( - AttributeSpec('attribute2', 'my second attribute', 'text', value='fixed'), - AttributeSpec('attribute3', 'my third attribute', 'text', value='fixed'), - AttributeSpec('attribute1', 'my first attribute', 'text') - )) - - self.assertEqual(self.inc_group_spec.get_attribute('attribute1'), - AttributeSpec('attribute1', 'my first attribute', 'text')) - self.assertEqual(self.inc_group_spec.get_attribute('attribute2'), - AttributeSpec('attribute2', 'my second attribute', 'text', value='fixed')) - self.assertEqual(self.inc_group_spec.get_attribute('attribute3'), - AttributeSpec('attribute3', 'my third attribute', 'text', value='fixed')) + self.assertTupleEqual( + self.inc_group_spec.attributes, + ( + AttributeSpec("attribute2", "my second attribute", "text", value="fixed"), + AttributeSpec("attribute3", "my third attribute", "text", value="fixed"), + AttributeSpec("attribute1", "my first attribute", "text"), + ), + ) + + self.assertEqual( + self.inc_group_spec.get_attribute("attribute1"), + AttributeSpec("attribute1", "my first attribute", "text"), + ) + self.assertEqual( + self.inc_group_spec.get_attribute("attribute2"), + AttributeSpec("attribute2", "my second attribute", "text", value="fixed"), + ) + self.assertEqual( + self.inc_group_spec.get_attribute("attribute3"), + AttributeSpec("attribute3", "my third attribute", "text", value="fixed"), + ) self.assertTrue(self.inc_group_spec.resolved) def test_is_inherited_spec(self): - self.assertFalse(self.def_group_spec.is_inherited_spec('attribute1')) - self.assertFalse(self.def_group_spec.is_inherited_spec('attribute2')) - self.assertTrue(self.inc_group_spec.is_inherited_spec( - AttributeSpec('attribute1', 'my first attribute', 'text') - )) - self.assertTrue(self.inc_group_spec.is_inherited_spec('attribute1')) - self.assertTrue(self.inc_group_spec.is_inherited_spec('attribute2')) - 
self.assertFalse(self.inc_group_spec.is_inherited_spec('attribute3')) - self.assertFalse(self.inc_group_spec.is_inherited_spec('attribute4')) + self.assertFalse(self.def_group_spec.is_inherited_spec("attribute1")) + self.assertFalse(self.def_group_spec.is_inherited_spec("attribute2")) + self.assertTrue( + self.inc_group_spec.is_inherited_spec(AttributeSpec("attribute1", "my first attribute", "text")) + ) + self.assertTrue(self.inc_group_spec.is_inherited_spec("attribute1")) + self.assertTrue(self.inc_group_spec.is_inherited_spec("attribute2")) + self.assertFalse(self.inc_group_spec.is_inherited_spec("attribute3")) + self.assertFalse(self.inc_group_spec.is_inherited_spec("attribute4")) def test_is_overridden_spec(self): - self.assertFalse(self.def_group_spec.is_overridden_spec('attribute1')) - self.assertFalse(self.def_group_spec.is_overridden_spec('attribute2')) - self.assertFalse(self.inc_group_spec.is_overridden_spec( - AttributeSpec('attribute1', 'my first attribute', 'text') - )) - self.assertFalse(self.inc_group_spec.is_overridden_spec('attribute1')) - self.assertTrue(self.inc_group_spec.is_overridden_spec('attribute2')) - self.assertFalse(self.inc_group_spec.is_overridden_spec('attribute3')) - self.assertFalse(self.inc_group_spec.is_overridden_spec('attribute4')) + self.assertFalse(self.def_group_spec.is_overridden_spec("attribute1")) + self.assertFalse(self.def_group_spec.is_overridden_spec("attribute2")) + self.assertFalse( + self.inc_group_spec.is_overridden_spec(AttributeSpec("attribute1", "my first attribute", "text")) + ) + self.assertFalse(self.inc_group_spec.is_overridden_spec("attribute1")) + self.assertTrue(self.inc_group_spec.is_overridden_spec("attribute2")) + self.assertFalse(self.inc_group_spec.is_overridden_spec("attribute3")) + self.assertFalse(self.inc_group_spec.is_overridden_spec("attribute4")) def test_is_inherited_attribute(self): - self.assertFalse(self.def_group_spec.is_inherited_attribute('attribute1')) - self.assertFalse(self.def_group_spec.is_inherited_attribute('attribute2')) - self.assertTrue(self.inc_group_spec.is_inherited_attribute('attribute1')) - self.assertTrue(self.inc_group_spec.is_inherited_attribute('attribute2')) - self.assertFalse(self.inc_group_spec.is_inherited_attribute('attribute3')) + self.assertFalse(self.def_group_spec.is_inherited_attribute("attribute1")) + self.assertFalse(self.def_group_spec.is_inherited_attribute("attribute2")) + self.assertTrue(self.inc_group_spec.is_inherited_attribute("attribute1")) + self.assertTrue(self.inc_group_spec.is_inherited_attribute("attribute2")) + self.assertFalse(self.inc_group_spec.is_inherited_attribute("attribute3")) with self.assertRaisesWith(ValueError, "Attribute 'attribute4' not found"): - self.inc_group_spec.is_inherited_attribute('attribute4') + self.inc_group_spec.is_inherited_attribute("attribute4") def test_is_overridden_attribute(self): - self.assertFalse(self.def_group_spec.is_overridden_attribute('attribute1')) - self.assertFalse(self.def_group_spec.is_overridden_attribute('attribute2')) - self.assertFalse(self.inc_group_spec.is_overridden_attribute('attribute1')) - self.assertTrue(self.inc_group_spec.is_overridden_attribute('attribute2')) - self.assertFalse(self.inc_group_spec.is_overridden_attribute('attribute3')) + self.assertFalse(self.def_group_spec.is_overridden_attribute("attribute1")) + self.assertFalse(self.def_group_spec.is_overridden_attribute("attribute2")) + self.assertFalse(self.inc_group_spec.is_overridden_attribute("attribute1")) + 
self.assertTrue(self.inc_group_spec.is_overridden_attribute("attribute2")) + self.assertFalse(self.inc_group_spec.is_overridden_attribute("attribute3")) with self.assertRaisesWith(ValueError, "Attribute 'attribute4' not found"): - self.inc_group_spec.is_overridden_attribute('attribute4') + self.inc_group_spec.is_overridden_attribute("attribute4") class GroupSpecWithLinksTest(TestCase): - def test_constructor(self): - link0 = LinkSpec(doc='Link 0', target_type='TargetType0') - link1 = LinkSpec(doc='Link 1', target_type='TargetType1') + link0 = LinkSpec(doc="Link 0", target_type="TargetType0") + link1 = LinkSpec(doc="Link 1", target_type="TargetType1") links = [link0, link1] - spec = GroupSpec( - doc='A test group', - name='root', - links=links - ) + spec = GroupSpec(doc="A test group", name="root", links=links) self.assertIs(spec, links[0].parent) self.assertIs(spec, links[1].parent) json.dumps(spec) def test_extension_no_overwrite(self): - link0 = LinkSpec(doc='Link 0', target_type='TargetType0') # test unnamed - link1 = LinkSpec(doc='Link 1', target_type='TargetType1', name='MyType1') # test named - link2 = LinkSpec(doc='Link 2', target_type='TargetType2', quantity='*') # test named, multiple + link0 = LinkSpec(doc="Link 0", target_type="TargetType0") # test unnamed + link1 = LinkSpec(doc="Link 1", target_type="TargetType1", name="MyType1") # test named + link2 = LinkSpec(doc="Link 2", target_type="TargetType2", quantity="*") # test named, multiple links = [link0, link1, link2] parent_spec = GroupSpec( - doc='A test group', - name='parent', + doc="A test group", + name="parent", links=links, - data_type_def='ParentType' + data_type_def="ParentType", ) child_spec = GroupSpec( - doc='A test group', - name='child', + doc="A test group", + name="child", data_type_inc=parent_spec, - data_type_def='ChildType' + data_type_def="ChildType", ) for link in links: @@ -444,27 +464,27 @@ def test_extension_no_overwrite(self): self.assertFalse(child_spec.is_overridden_spec(link)) def test_extension_overwrite(self): - link0 = LinkSpec(doc='Link 0', target_type='TargetType0', name='MyType0') - link1 = LinkSpec(doc='Link 1', target_type='TargetType1', name='MyType1') + link0 = LinkSpec(doc="Link 0", target_type="TargetType0", name="MyType0") + link1 = LinkSpec(doc="Link 1", target_type="TargetType1", name="MyType1") # NOTE overwriting unnamed LinkSpec is not allowed # NOTE overwriting spec with quantity that could be >1 is not allowed links = [link0, link1] parent_spec = GroupSpec( - doc='A test group', - name='parent', + doc="A test group", + name="parent", links=links, - data_type_def='ParentType' + data_type_def="ParentType", ) - link0_overwrite = LinkSpec(doc='New link 0', target_type='TargetType0', name='MyType0') - link1_overwrite = LinkSpec(doc='New link 1', target_type='TargetType1Child', name='MyType1') + link0_overwrite = LinkSpec(doc="New link 0", target_type="TargetType0", name="MyType0") + link1_overwrite = LinkSpec(doc="New link 1", target_type="TargetType1Child", name="MyType1") overwritten_links = [link0_overwrite, link1_overwrite] child_spec = GroupSpec( - doc='A test group', - name='child', + doc="A test group", + name="child", links=overwritten_links, data_type_inc=parent_spec, - data_type_def='ChildType' + data_type_def="ChildType", ) for link in overwritten_links: @@ -474,256 +494,254 @@ def test_extension_overwrite(self): class SpecWithDupsTest(TestCase): - def test_two_unnamed_group_same_type(self): """Test creating a group contains multiple unnamed groups with type X.""" - child0 = 
GroupSpec(doc='Group 0', data_type_inc='Type0') - child1 = GroupSpec(doc='Group 1', data_type_inc='Type0') + child0 = GroupSpec(doc="Group 0", data_type_inc="Type0") + child1 = GroupSpec(doc="Group 1", data_type_inc="Type0") msg = "Cannot have multiple groups/datasets with the same data type without specifying name" with self.assertRaisesWith(ValueError, msg): GroupSpec( - doc='A test group', - name='parent', + doc="A test group", + name="parent", groups=[child0, child1], - data_type_def='ParentType' + data_type_def="ParentType", ) def test_named_unnamed_group_with_def_same_type(self): """Test get_data_type when a group contains both a named and unnamed group with type X.""" - child0 = GroupSpec(doc='Group 0', data_type_def='Type0', name='type0') - child1 = GroupSpec(doc='Group 1', data_type_inc='Type0') + child0 = GroupSpec(doc="Group 0", data_type_def="Type0", name="type0") + child1 = GroupSpec(doc="Group 1", data_type_inc="Type0") parent_spec = GroupSpec( - doc='A test group', - name='parent', + doc="A test group", + name="parent", groups=[child0, child1], - data_type_def='ParentType' + data_type_def="ParentType", ) - self.assertIs(parent_spec.get_data_type('Type0'), child1) + self.assertIs(parent_spec.get_data_type("Type0"), child1) def test_named_unnamed_group_same_type(self): """Test get_data_type when a group contains both a named and unnamed group with type X.""" - child0 = GroupSpec(doc='Group 0', data_type_inc='Type0', name='type0') - child1 = GroupSpec(doc='Group 1', data_type_inc='Type0', name='type1') - child2 = GroupSpec(doc='Group 2', data_type_inc='Type0') + child0 = GroupSpec(doc="Group 0", data_type_inc="Type0", name="type0") + child1 = GroupSpec(doc="Group 1", data_type_inc="Type0", name="type1") + child2 = GroupSpec(doc="Group 2", data_type_inc="Type0") parent_spec = GroupSpec( - doc='A test group', - name='parent', + doc="A test group", + name="parent", groups=[child0, child1, child2], - data_type_def='ParentType' + data_type_def="ParentType", ) - self.assertIs(parent_spec.get_data_type('Type0'), child2) + self.assertIs(parent_spec.get_data_type("Type0"), child2) def test_unnamed_named_group_same_type(self): """Test get_data_type when a group contains both an unnamed and named group with type X.""" - child0 = GroupSpec(doc='Group 0', data_type_inc='Type0') - child1 = GroupSpec(doc='Group 1', data_type_inc='Type0', name='type1') + child0 = GroupSpec(doc="Group 0", data_type_inc="Type0") + child1 = GroupSpec(doc="Group 1", data_type_inc="Type0", name="type1") parent_spec = GroupSpec( - doc='A test group', - name='parent', + doc="A test group", + name="parent", groups=[child0, child1], - data_type_def='ParentType' + data_type_def="ParentType", ) - self.assertIs(parent_spec.get_data_type('Type0'), child0) + self.assertIs(parent_spec.get_data_type("Type0"), child0) def test_two_named_group_same_type(self): """Test get_data_type when a group contains multiple named groups with type X.""" - child0 = GroupSpec(doc='Group 0', data_type_inc='Type0', name='group0') - child1 = GroupSpec(doc='Group 1', data_type_inc='Type0', name='group1') + child0 = GroupSpec(doc="Group 0", data_type_inc="Type0", name="group0") + child1 = GroupSpec(doc="Group 1", data_type_inc="Type0", name="group1") parent_spec = GroupSpec( - doc='A test group', - name='parent', + doc="A test group", + name="parent", groups=[child0, child1], - data_type_def='ParentType' + data_type_def="ParentType", ) - self.assertEqual(parent_spec.get_data_type('Type0'), [child0, child1]) + 
self.assertEqual(parent_spec.get_data_type("Type0"), [child0, child1]) def test_two_unnamed_datasets_same_type(self): """Test creating a group contains multiple unnamed datasets with type X.""" - child0 = DatasetSpec(doc='Group 0', data_type_inc='Type0') - child1 = DatasetSpec(doc='Group 1', data_type_inc='Type0') + child0 = DatasetSpec(doc="Group 0", data_type_inc="Type0") + child1 = DatasetSpec(doc="Group 1", data_type_inc="Type0") msg = "Cannot have multiple groups/datasets with the same data type without specifying name" with self.assertRaisesWith(ValueError, msg): GroupSpec( - doc='A test group', - name='parent', + doc="A test group", + name="parent", datasets=[child0, child1], - data_type_def='ParentType' + data_type_def="ParentType", ) def test_named_unnamed_dataset_with_def_same_type(self): """Test get_data_type when a group contains both a named and unnamed dataset with type X.""" - child0 = DatasetSpec(doc='Group 0', data_type_def='Type0', name='type0') - child1 = DatasetSpec(doc='Group 1', data_type_inc='Type0') + child0 = DatasetSpec(doc="Group 0", data_type_def="Type0", name="type0") + child1 = DatasetSpec(doc="Group 1", data_type_inc="Type0") parent_spec = GroupSpec( - doc='A test group', - name='parent', + doc="A test group", + name="parent", datasets=[child0, child1], - data_type_def='ParentType' + data_type_def="ParentType", ) - self.assertIs(parent_spec.get_data_type('Type0'), child1) + self.assertIs(parent_spec.get_data_type("Type0"), child1) def test_named_unnamed_dataset_same_type(self): """Test get_data_type when a group contains both a named and unnamed dataset with type X.""" - child0 = DatasetSpec(doc='Group 0', data_type_inc='Type0', name='type0') - child1 = DatasetSpec(doc='Group 1', data_type_inc='Type0') + child0 = DatasetSpec(doc="Group 0", data_type_inc="Type0", name="type0") + child1 = DatasetSpec(doc="Group 1", data_type_inc="Type0") parent_spec = GroupSpec( - doc='A test group', - name='parent', + doc="A test group", + name="parent", datasets=[child0, child1], - data_type_def='ParentType' + data_type_def="ParentType", ) - self.assertIs(parent_spec.get_data_type('Type0'), child1) + self.assertIs(parent_spec.get_data_type("Type0"), child1) def test_two_named_unnamed_dataset_same_type(self): """Test get_data_type when a group contains two named and one unnamed dataset with type X.""" - child0 = DatasetSpec(doc='Group 0', data_type_inc='Type0', name='type0') - child1 = DatasetSpec(doc='Group 1', data_type_inc='Type0', name='type1') - child2 = DatasetSpec(doc='Group 2', data_type_inc='Type0') + child0 = DatasetSpec(doc="Group 0", data_type_inc="Type0", name="type0") + child1 = DatasetSpec(doc="Group 1", data_type_inc="Type0", name="type1") + child2 = DatasetSpec(doc="Group 2", data_type_inc="Type0") parent_spec = GroupSpec( - doc='A test group', - name='parent', + doc="A test group", + name="parent", datasets=[child0, child1, child2], - data_type_def='ParentType' + data_type_def="ParentType", ) - self.assertIs(parent_spec.get_data_type('Type0'), child2) + self.assertIs(parent_spec.get_data_type("Type0"), child2) def test_unnamed_named_dataset_same_type(self): """Test get_data_type when a group contains both an unnamed and named dataset with type X.""" - child0 = DatasetSpec(doc='Group 0', data_type_inc='Type0') - child1 = DatasetSpec(doc='Group 1', data_type_inc='Type0', name='type1') + child0 = DatasetSpec(doc="Group 0", data_type_inc="Type0") + child1 = DatasetSpec(doc="Group 1", data_type_inc="Type0", name="type1") parent_spec = GroupSpec( - doc='A test 
group', - name='parent', + doc="A test group", + name="parent", datasets=[child0, child1], - data_type_def='ParentType' + data_type_def="ParentType", ) - self.assertIs(parent_spec.get_data_type('Type0'), child0) + self.assertIs(parent_spec.get_data_type("Type0"), child0) def test_two_named_datasets_same_type(self): """Test get_data_type when a group contains multiple named datasets with type X.""" - child0 = DatasetSpec(doc='Group 0', data_type_inc='Type0', name='group0') - child1 = DatasetSpec(doc='Group 1', data_type_inc='Type0', name='group1') + child0 = DatasetSpec(doc="Group 0", data_type_inc="Type0", name="group0") + child1 = DatasetSpec(doc="Group 1", data_type_inc="Type0", name="group1") parent_spec = GroupSpec( - doc='A test group', - name='parent', + doc="A test group", + name="parent", datasets=[child0, child1], - data_type_def='ParentType' + data_type_def="ParentType", ) - self.assertEqual(parent_spec.get_data_type('Type0'), [child0, child1]) + self.assertEqual(parent_spec.get_data_type("Type0"), [child0, child1]) def test_three_named_datasets_same_type(self): """Test get_target_type when a group contains three named links with type X.""" - child0 = DatasetSpec(doc='Group 0', data_type_inc='Type0', name='group0') - child1 = DatasetSpec(doc='Group 1', data_type_inc='Type0', name='group1') - child2 = DatasetSpec(doc='Group 2', data_type_inc='Type0', name='group2') + child0 = DatasetSpec(doc="Group 0", data_type_inc="Type0", name="group0") + child1 = DatasetSpec(doc="Group 1", data_type_inc="Type0", name="group1") + child2 = DatasetSpec(doc="Group 2", data_type_inc="Type0", name="group2") parent_spec = GroupSpec( - doc='A test group', - name='parent', + doc="A test group", + name="parent", datasets=[child0, child1, child2], - data_type_def='ParentType' + data_type_def="ParentType", ) - self.assertEqual(parent_spec.get_data_type('Type0'), [child0, child1, child2]) + self.assertEqual(parent_spec.get_data_type("Type0"), [child0, child1, child2]) def test_two_unnamed_links_same_type(self): """Test creating a group contains multiple unnamed links with type X.""" - child0 = LinkSpec(doc='Group 0', target_type='Type0') - child1 = LinkSpec(doc='Group 1', target_type='Type0') + child0 = LinkSpec(doc="Group 0", target_type="Type0") + child1 = LinkSpec(doc="Group 1", target_type="Type0") msg = "Cannot have multiple links with the same target type without specifying name" with self.assertRaisesWith(ValueError, msg): GroupSpec( - doc='A test group', - name='parent', + doc="A test group", + name="parent", links=[child0, child1], - data_type_def='ParentType' + data_type_def="ParentType", ) def test_named_unnamed_link_same_type(self): """Test get_target_type when a group contains both a named and unnamed link with type X.""" - child0 = LinkSpec(doc='Group 0', target_type='Type0', name='type0') - child1 = LinkSpec(doc='Group 1', target_type='Type0') + child0 = LinkSpec(doc="Group 0", target_type="Type0", name="type0") + child1 = LinkSpec(doc="Group 1", target_type="Type0") parent_spec = GroupSpec( - doc='A test group', - name='parent', + doc="A test group", + name="parent", links=[child0, child1], - data_type_def='ParentType' + data_type_def="ParentType", ) - self.assertIs(parent_spec.get_target_type('Type0'), child1) + self.assertIs(parent_spec.get_target_type("Type0"), child1) def test_two_named_unnamed_link_same_type(self): """Test get_target_type when a group contains two named and one unnamed link with type X.""" - child0 = LinkSpec(doc='Group 0', target_type='Type0', name='type0') - child1 = 
LinkSpec(doc='Group 1', target_type='Type0', name='type1') - child2 = LinkSpec(doc='Group 2', target_type='Type0') + child0 = LinkSpec(doc="Group 0", target_type="Type0", name="type0") + child1 = LinkSpec(doc="Group 1", target_type="Type0", name="type1") + child2 = LinkSpec(doc="Group 2", target_type="Type0") parent_spec = GroupSpec( - doc='A test group', - name='parent', + doc="A test group", + name="parent", links=[child0, child1, child2], - data_type_def='ParentType' + data_type_def="ParentType", ) - self.assertIs(parent_spec.get_target_type('Type0'), child2) + self.assertIs(parent_spec.get_target_type("Type0"), child2) def test_unnamed_named_link_same_type(self): """Test get_target_type when a group contains both an unnamed and named link with type X.""" - child0 = LinkSpec(doc='Group 0', target_type='Type0') - child1 = LinkSpec(doc='Group 1', target_type='Type0', name='type1') + child0 = LinkSpec(doc="Group 0", target_type="Type0") + child1 = LinkSpec(doc="Group 1", target_type="Type0", name="type1") parent_spec = GroupSpec( - doc='A test group', - name='parent', + doc="A test group", + name="parent", links=[child0, child1], - data_type_def='ParentType' + data_type_def="ParentType", ) - self.assertIs(parent_spec.get_target_type('Type0'), child0) + self.assertIs(parent_spec.get_target_type("Type0"), child0) def test_two_named_links_same_type(self): """Test get_target_type when a group contains multiple named links with type X.""" - child0 = LinkSpec(doc='Group 0', target_type='Type0', name='group0') - child1 = LinkSpec(doc='Group 1', target_type='Type0', name='group1') + child0 = LinkSpec(doc="Group 0", target_type="Type0", name="group0") + child1 = LinkSpec(doc="Group 1", target_type="Type0", name="group1") parent_spec = GroupSpec( - doc='A test group', - name='parent', + doc="A test group", + name="parent", links=[child0, child1], - data_type_def='ParentType' + data_type_def="ParentType", ) - self.assertEqual(parent_spec.get_target_type('Type0'), [child0, child1]) + self.assertEqual(parent_spec.get_target_type("Type0"), [child0, child1]) def test_three_named_links_same_type(self): """Test get_target_type when a group contains three named links with type X.""" - child0 = LinkSpec(doc='Group 0', target_type='Type0', name='type0') - child1 = LinkSpec(doc='Group 1', target_type='Type0', name='type1') - child2 = LinkSpec(doc='Group 2', target_type='Type0', name='type2') + child0 = LinkSpec(doc="Group 0", target_type="Type0", name="type0") + child1 = LinkSpec(doc="Group 1", target_type="Type0", name="type1") + child2 = LinkSpec(doc="Group 2", target_type="Type0", name="type2") parent_spec = GroupSpec( - doc='A test group', - name='parent', + doc="A test group", + name="parent", links=[child0, child1, child2], - data_type_def='ParentType' + data_type_def="ParentType", ) - self.assertEqual(parent_spec.get_target_type('Type0'), [child0, child1, child2]) + self.assertEqual(parent_spec.get_target_type("Type0"), [child0, child1, child2]) class SpecWithGroupsLinksTest(TestCase): - def test_unnamed_group_link_same_type(self): - child = GroupSpec(doc='Group 0', data_type_inc='Type0') - link = LinkSpec(doc='Link 0', target_type='Type0') + child = GroupSpec(doc="Group 0", data_type_inc="Type0") + link = LinkSpec(doc="Link 0", target_type="Type0") parent_spec = GroupSpec( - doc='A test group', - name='parent', + doc="A test group", + name="parent", groups=[child], links=[link], - data_type_def='ParentType' + data_type_def="ParentType", ) - self.assertIs(parent_spec.get_data_type('Type0'), child) - 
self.assertIs(parent_spec.get_target_type('Type0'), link) + self.assertIs(parent_spec.get_data_type("Type0"), child) + self.assertIs(parent_spec.get_target_type("Type0"), link) def test_unnamed_dataset_link_same_type(self): - child = DatasetSpec(doc='Dataset 0', data_type_inc='Type0') - link = LinkSpec(doc='Link 0', target_type='Type0') + child = DatasetSpec(doc="Dataset 0", data_type_inc="Type0") + link = LinkSpec(doc="Link 0", target_type="Type0") parent_spec = GroupSpec( - doc='A test group', - name='parent', + doc="A test group", + name="parent", datasets=[child], links=[link], - data_type_def='ParentType' + data_type_def="ParentType", ) - self.assertIs(parent_spec.get_data_type('Type0'), child) - self.assertIs(parent_spec.get_target_type('Type0'), link) + self.assertIs(parent_spec.get_data_type("Type0"), child) + self.assertIs(parent_spec.get_target_type("Type0"), link) diff --git a/tests/unit/spec_tests/test_link_spec.py b/tests/unit/spec_tests/test_link_spec.py index e6c680b7c..f6b57c1c5 100644 --- a/tests/unit/spec_tests/test_link_spec.py +++ b/tests/unit/spec_tests/test_link_spec.py @@ -5,65 +5,64 @@ class LinkSpecTests(TestCase): - def test_constructor(self): spec = LinkSpec( - doc='A test link', - target_type='Group1', - quantity='+', - name='Link1', + doc="A test link", + target_type="Group1", + quantity="+", + name="Link1", ) - self.assertEqual(spec.doc, 'A test link') - self.assertEqual(spec.target_type, 'Group1') - self.assertEqual(spec.data_type_inc, 'Group1') - self.assertEqual(spec.quantity, '+') - self.assertEqual(spec.name, 'Link1') + self.assertEqual(spec.doc, "A test link") + self.assertEqual(spec.target_type, "Group1") + self.assertEqual(spec.data_type_inc, "Group1") + self.assertEqual(spec.quantity, "+") + self.assertEqual(spec.name, "Link1") json.dumps(spec) def test_constructor_target_spec_def(self): group_spec_def = GroupSpec( - data_type_def='Group1', - doc='A test group', + data_type_def="Group1", + doc="A test group", ) spec = LinkSpec( - doc='A test link', + doc="A test link", target_type=group_spec_def, ) - self.assertEqual(spec.target_type, 'Group1') + self.assertEqual(spec.target_type, "Group1") json.dumps(spec) def test_constructor_target_spec_inc(self): group_spec_inc = GroupSpec( - data_type_inc='Group1', - doc='A test group', + data_type_inc="Group1", + doc="A test group", ) msg = "'target_type' must be a string or a GroupSpec or DatasetSpec with a 'data_type_def' key." 
with self.assertRaisesWith(ValueError, msg): LinkSpec( - doc='A test link', + doc="A test link", target_type=group_spec_inc, ) def test_constructor_defaults(self): spec = LinkSpec( - doc='A test link', - target_type='Group1', + doc="A test link", + target_type="Group1", ) self.assertEqual(spec.quantity, 1) self.assertIsNone(spec.name) json.dumps(spec) def test_required_is_many(self): - quantity_opts = ['?', 1, '*', '+'] + quantity_opts = ["?", 1, "*", "+"] is_required = [False, True, False, True] is_many = [False, False, True, True] - for (quantity, req, many) in zip(quantity_opts, is_required, is_many): + for quantity, req, many in zip(quantity_opts, is_required, is_many): with self.subTest(quantity=quantity): spec = LinkSpec( - doc='A test link', - target_type='Group1', + doc="A test link", + target_type="Group1", quantity=quantity, - name='Link1', + name="Link1", ) self.assertEqual(spec.required, req) self.assertEqual(spec.is_many(), many) diff --git a/tests/unit/spec_tests/test_load_namespace.py b/tests/unit/spec_tests/test_load_namespace.py index 5d7e6573c..e64e33bcd 100644 --- a/tests/unit/spec_tests/test_load_namespace.py +++ b/tests/unit/spec_tests/test_load_namespace.py @@ -1,91 +1,115 @@ import json import os -import ruamel.yaml as yaml -from tempfile import gettempdir import warnings +from tempfile import gettempdir + +import ruamel.yaml as yaml from hdmf.common import get_type_map -from hdmf.spec import AttributeSpec, DatasetSpec, GroupSpec, SpecNamespace, NamespaceCatalog, NamespaceBuilder +from hdmf.spec import ( + AttributeSpec, + DatasetSpec, + GroupSpec, + NamespaceBuilder, + NamespaceCatalog, + SpecNamespace, +) from hdmf.testing import TestCase, remove_test_file -from tests.unit.helpers.utils import CustomGroupSpec, CustomDatasetSpec, CustomSpecNamespace +from ..helpers.utils import CustomDatasetSpec, CustomGroupSpec, CustomSpecNamespace class TestSpecLoad(TestCase): - NS_NAME = 'test_ns' + NS_NAME = "test_ns" def setUp(self): self.attributes = [ - AttributeSpec('attribute1', 'my first attribute', 'text'), - AttributeSpec('attribute2', 'my second attribute', 'text') + AttributeSpec("attribute1", "my first attribute", "text"), + AttributeSpec("attribute2", "my second attribute", "text"), ] self.dset1_attributes = [ - AttributeSpec('attribute3', 'my third attribute', 'text'), - AttributeSpec('attribute4', 'my fourth attribute', 'text') + AttributeSpec("attribute3", "my third attribute", "text"), + AttributeSpec("attribute4", "my fourth attribute", "text"), ] self.dset2_attributes = [ - AttributeSpec('attribute5', 'my fifth attribute', 'text'), - AttributeSpec('attribute6', 'my sixth attribute', 'text') + AttributeSpec("attribute5", "my fifth attribute", "text"), + AttributeSpec("attribute6", "my sixth attribute", "text"), ] self.datasets = [ - DatasetSpec('my first dataset', - 'int', - name='dataset1', - attributes=self.dset1_attributes, - linkable=True), - DatasetSpec('my second dataset', - 'int', - name='dataset2', - dims=(None, None), - attributes=self.dset2_attributes, - linkable=True, - data_type_def='VoltageArray') + DatasetSpec( + "my first dataset", + "int", + name="dataset1", + attributes=self.dset1_attributes, + linkable=True, + ), + DatasetSpec( + "my second dataset", + "int", + name="dataset2", + dims=(None, None), + attributes=self.dset2_attributes, + linkable=True, + data_type_def="VoltageArray", + ), ] - self.spec = GroupSpec('A test group', - name='root_constructor_datatype', - datasets=self.datasets, - attributes=self.attributes, - linkable=False, - 
data_type_def='EphysData') + self.spec = GroupSpec( + "A test group", + name="root_constructor_datatype", + datasets=self.datasets, + attributes=self.attributes, + linkable=False, + data_type_def="EphysData", + ) dset1_attributes_ext = [ - AttributeSpec('dset1_extra_attribute', 'an extra attribute for the first dataset', 'text') + AttributeSpec( + "dset1_extra_attribute", + "an extra attribute for the first dataset", + "text", + ) ] self.ext_datasets = [ - DatasetSpec('my first dataset extension', - 'int', - name='dataset1', - attributes=dset1_attributes_ext, - linkable=True), + DatasetSpec( + "my first dataset extension", + "int", + name="dataset1", + attributes=dset1_attributes_ext, + linkable=True, + ), ] self.ext_attributes = [ - AttributeSpec('ext_extra_attribute', 'an extra attribute for the group', 'text'), + AttributeSpec( + "ext_extra_attribute", + "an extra attribute for the group", + "text", + ), ] - self.ext_spec = GroupSpec('A test group extension', - name='root_constructor_datatype', - datasets=self.ext_datasets, - attributes=self.ext_attributes, - linkable=False, - data_type_inc='EphysData', - data_type_def='SpikeData') - to_dump = {'groups': [self.spec, self.ext_spec]} - self.specs_path = 'test_load_namespace.specs.yaml' - self.namespace_path = 'test_load_namespace.namespace.yaml' - with open(self.specs_path, 'w') as tmp: - yaml_obj = yaml.YAML(typ='safe', pure=True) + self.ext_spec = GroupSpec( + "A test group extension", + name="root_constructor_datatype", + datasets=self.ext_datasets, + attributes=self.ext_attributes, + linkable=False, + data_type_inc="EphysData", + data_type_def="SpikeData", + ) + to_dump = {"groups": [self.spec, self.ext_spec]} + self.specs_path = "test_load_namespace.specs.yaml" + self.namespace_path = "test_load_namespace.namespace.yaml" + with open(self.specs_path, "w") as tmp: + yaml_obj = yaml.YAML(typ="safe", pure=True) yaml_obj.default_flow_style = False yaml_obj.dump(json.loads(json.dumps(to_dump)), tmp) ns_dict = { - 'doc': 'a test namespace', - 'name': self.NS_NAME, - 'schema': [ - {'source': self.specs_path} - ], - 'version': '0.1.0' + "doc": "a test namespace", + "name": self.NS_NAME, + "schema": [{"source": self.specs_path}], + "version": "0.1.0", } self.namespace = SpecNamespace.build_namespace(**ns_dict) - to_dump = {'namespaces': [self.namespace]} - with open(self.namespace_path, 'w') as tmp: - yaml_obj = yaml.YAML(typ='safe', pure=True) + to_dump = {"namespaces": [self.namespace]} + with open(self.namespace_path, "w") as tmp: + yaml_obj = yaml.YAML(typ="safe", pure=True) yaml_obj.default_flow_style = False yaml_obj.dump(json.loads(json.dumps(to_dump)), tmp) self.ns_catalog = NamespaceCatalog() @@ -98,8 +122,8 @@ def tearDown(self): def test_inherited_attributes(self): self.ns_catalog.load_namespaces(self.namespace_path, resolve=True) - ts_spec = self.ns_catalog.get_spec(self.NS_NAME, 'EphysData') - es_spec = self.ns_catalog.get_spec(self.NS_NAME, 'SpikeData') + ts_spec = self.ns_catalog.get_spec(self.NS_NAME, "EphysData") + es_spec = self.ns_catalog.get_spec(self.NS_NAME, "SpikeData") ts_attrs = {s.name for s in ts_spec.attributes} es_attrs = {s.name for s in es_spec.attributes} for attr in ts_attrs: @@ -115,7 +139,7 @@ def test_inherited_attributes(self): def test_inherited_attributes_not_resolved(self): self.ns_catalog.load_namespaces(self.namespace_path, resolve=False) - es_spec = self.ns_catalog.get_spec(self.NS_NAME, 'SpikeData') + es_spec = self.ns_catalog.get_spec(self.NS_NAME, "SpikeData") src_attrs = {s.name for s in 
self.ext_attributes} ext_attrs = {s.name for s in es_spec.attributes} self.assertSetEqual(src_attrs, ext_attrs) @@ -125,15 +149,14 @@ def test_inherited_attributes_not_resolved(self): class TestSpecLoadEdgeCase(TestCase): - def setUp(self): - self.specs_path = 'test_load_namespace.specs.yaml' - self.namespace_path = 'test_load_namespace.namespace.yaml' + self.specs_path = "test_load_namespace.specs.yaml" + self.namespace_path = "test_load_namespace.namespace.yaml" # write basically empty specs file - to_dump = {'groups': []} - with open(self.specs_path, 'w') as tmp: - yaml_obj = yaml.YAML(typ='safe', pure=True) + to_dump = {"groups": []} + with open(self.specs_path, "w") as tmp: + yaml_obj = yaml.YAML(typ="safe", pure=True) yaml_obj.default_flow_style = False yaml_obj.dump(json.loads(json.dumps(to_dump)), tmp) @@ -145,14 +168,15 @@ def test_build_namespace_missing_version(self): """Test that building/creating a SpecNamespace without a version works but raises a warning.""" # create namespace without version key ns_dict = { - 'doc': 'a test namespace', - 'name': 'test_ns', - 'schema': [ - {'source': self.specs_path} - ], + "doc": "a test namespace", + "name": "test_ns", + "schema": [{"source": self.specs_path}], } - msg = ("Loaded namespace 'test_ns' is missing the required key 'version'. Version will be set to " - "'%s'. Please notify the extension author." % SpecNamespace.UNVERSIONED) + msg = ( + "Loaded namespace 'test_ns' is missing the required key 'version'. Version" + " will be set to '%s'. Please notify the extension author." + % SpecNamespace.UNVERSIONED + ) with self.assertWarnsWith(UserWarning, msg): namespace = SpecNamespace.build_namespace(**ns_dict) @@ -162,50 +186,52 @@ def test_load_namespace_none_version(self): """Test that reading a namespace file without a version works but raises a warning.""" # create namespace with version key (remove it later) ns_dict = { - 'doc': 'a test namespace', - 'name': 'test_ns', - 'schema': [ - {'source': self.specs_path} - ], - 'version': '0.0.1' + "doc": "a test namespace", + "name": "test_ns", + "schema": [{"source": self.specs_path}], + "version": "0.0.1", } namespace = SpecNamespace.build_namespace(**ns_dict) - namespace['version'] = None # work around lack of setter to remove version key + namespace["version"] = None # work around lack of setter to remove version key # write the namespace to file without version key - to_dump = {'namespaces': [namespace]} - with open(self.namespace_path, 'w') as tmp: - yaml_obj = yaml.YAML(typ='safe', pure=True) + to_dump = {"namespaces": [namespace]} + with open(self.namespace_path, "w") as tmp: + yaml_obj = yaml.YAML(typ="safe", pure=True) yaml_obj.default_flow_style = False yaml_obj.dump(json.loads(json.dumps(to_dump)), tmp) # load the namespace from file ns_catalog = NamespaceCatalog() - msg = ("Loaded namespace 'test_ns' is missing the required key 'version'. Version will be set to " - "'%s'. Please notify the extension author." % SpecNamespace.UNVERSIONED) + msg = ( + "Loaded namespace 'test_ns' is missing the required key 'version'. Version" + " will be set to '%s'. Please notify the extension author." 
+ % SpecNamespace.UNVERSIONED + ) with self.assertWarnsWith(UserWarning, msg): ns_catalog.load_namespaces(self.namespace_path) - self.assertEqual(ns_catalog.get_namespace('test_ns').version, SpecNamespace.UNVERSIONED) + self.assertEqual( + ns_catalog.get_namespace("test_ns").version, + SpecNamespace.UNVERSIONED, + ) def test_load_namespace_unversioned_version(self): """Test that reading a namespace file with version=unversioned string works but raises a warning.""" # create namespace with version key (remove it later) ns_dict = { - 'doc': 'a test namespace', - 'name': 'test_ns', - 'schema': [ - {'source': self.specs_path} - ], - 'version': '0.0.1' + "doc": "a test namespace", + "name": "test_ns", + "schema": [{"source": self.specs_path}], + "version": "0.0.1", } namespace = SpecNamespace.build_namespace(**ns_dict) - namespace['version'] = str(SpecNamespace.UNVERSIONED) # work around lack of setter to remove version key + namespace["version"] = str(SpecNamespace.UNVERSIONED) # work around lack of setter to remove version key # write the namespace to file without version key - to_dump = {'namespaces': [namespace]} - with open(self.namespace_path, 'w') as tmp: - yaml_obj = yaml.YAML(typ='safe', pure=True) + to_dump = {"namespaces": [namespace]} + with open(self.namespace_path, "w") as tmp: + yaml_obj = yaml.YAML(typ="safe", pure=True) yaml_obj.default_flow_style = False yaml_obj.dump(json.loads(json.dumps(to_dump)), tmp) @@ -215,7 +241,10 @@ def test_load_namespace_unversioned_version(self): with self.assertWarnsWith(UserWarning, msg): ns_catalog.load_namespaces(self.namespace_path) - self.assertEqual(ns_catalog.get_namespace('test_ns').version, SpecNamespace.UNVERSIONED) + self.assertEqual( + ns_catalog.get_namespace("test_ns").version, + SpecNamespace.UNVERSIONED, + ) def test_missing_version_string(self): """Test that the constant variable representing a missing version has not changed.""" @@ -225,38 +254,40 @@ def test_get_namespace_missing_version(self): """Test that SpecNamespace.version returns the constant for a missing version if version gets removed.""" # create namespace with version key (remove it later) ns_dict = { - 'doc': 'a test namespace', - 'name': 'test_ns', - 'schema': [ - {'source': self.specs_path} - ], - 'version': '0.0.1' + "doc": "a test namespace", + "name": "test_ns", + "schema": [{"source": self.specs_path}], + "version": "0.0.1", } namespace = SpecNamespace.build_namespace(**ns_dict) - namespace['version'] = None # work around lack of setter to remove version key + namespace["version"] = None # work around lack of setter to remove version key self.assertEqual(namespace.version, SpecNamespace.UNVERSIONED) class TestCatchDupNS(TestCase): - def setUp(self): self.tempdir = gettempdir() - self.ext_source1 = 'extension1.yaml' - self.ns_path1 = 'namespace1.yaml' - self.ext_source2 = 'extension2.yaml' - self.ns_path2 = 'namespace2.yaml' + self.ext_source1 = "extension1.yaml" + self.ns_path1 = "namespace1.yaml" + self.ext_source2 = "extension2.yaml" + self.ns_path2 = "namespace2.yaml" def tearDown(self): - for f in (self.ext_source1, self.ns_path1, self.ext_source2, self.ns_path2): + for f in ( + self.ext_source1, + self.ns_path1, + self.ext_source2, + self.ns_path2, + ): remove_test_file(os.path.join(self.tempdir, f)) def test_catch_dup_name(self): - ns_builder1 = NamespaceBuilder('Extension doc', "test_ext", version='0.1.0') - ns_builder1.add_spec(self.ext_source1, GroupSpec('doc', data_type_def='MyType')) + ns_builder1 = NamespaceBuilder("Extension doc", "test_ext", 
version="0.1.0") + ns_builder1.add_spec(self.ext_source1, GroupSpec("doc", data_type_def="MyType")) ns_builder1.export(self.ns_path1, outdir=self.tempdir) - ns_builder2 = NamespaceBuilder('Extension doc', "test_ext", version='0.2.0') - ns_builder2.add_spec(self.ext_source2, GroupSpec('doc', data_type_def='MyType')) + ns_builder2 = NamespaceBuilder("Extension doc", "test_ext", version="0.2.0") + ns_builder2.add_spec(self.ext_source2, GroupSpec("doc", data_type_def="MyType")) ns_builder2.export(self.ns_path2, outdir=self.tempdir) ns_catalog = NamespaceCatalog() @@ -267,11 +298,11 @@ def test_catch_dup_name(self): ns_catalog.load_namespaces(os.path.join(self.tempdir, self.ns_path2)) def test_catch_dup_name_same_version(self): - ns_builder1 = NamespaceBuilder('Extension doc', "test_ext", version='0.1.0') - ns_builder1.add_spec(self.ext_source1, GroupSpec('doc', data_type_def='MyType')) + ns_builder1 = NamespaceBuilder("Extension doc", "test_ext", version="0.1.0") + ns_builder1.add_spec(self.ext_source1, GroupSpec("doc", data_type_def="MyType")) ns_builder1.export(self.ns_path1, outdir=self.tempdir) - ns_builder2 = NamespaceBuilder('Extension doc', "test_ext", version='0.1.0') - ns_builder2.add_spec(self.ext_source2, GroupSpec('doc', data_type_def='MyType')) + ns_builder2 = NamespaceBuilder("Extension doc", "test_ext", version="0.1.0") + ns_builder2.add_spec(self.ext_source2, GroupSpec("doc", data_type_def="MyType")) ns_builder2.export(self.ns_path2, outdir=self.tempdir) ns_catalog = NamespaceCatalog() @@ -286,7 +317,6 @@ def test_catch_dup_name_same_version(self): class TestCustomSpecClasses(TestCase): - def setUp(self): # noqa: C901 self.ns_catalog = NamespaceCatalog(CustomGroupSpec, CustomDatasetSpec, CustomSpecNamespace) hdmf_typemap = get_type_map() @@ -298,62 +328,91 @@ def test_constructor_getters(self): self.assertEqual(self.ns_catalog.spec_namespace_cls, CustomSpecNamespace) def test_load_namespaces(self): - namespace_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test.namespace.yaml') + namespace_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "test.namespace.yaml") namespace_deps = self.ns_catalog.load_namespaces(namespace_path) # test that the dependencies are correct, including dependencies of the dependencies - expected = set(['Data', 'Container', 'DynamicTable', 'ElementIdentifiers', 'VectorData']) - self.assertSetEqual(set(namespace_deps['test']['hdmf-common']), expected) + expected = set( + [ + "Data", + "Container", + "DynamicTable", + "ElementIdentifiers", + "VectorData", + ] + ) + self.assertSetEqual(set(namespace_deps["test"]["hdmf-common"]), expected) # test that the types are loaded - types = self.ns_catalog.get_types('test.base.yaml') - expected = ('TestData', 'TestContainer', 'TestTable') + types = self.ns_catalog.get_types("test.base.yaml") + expected = ("TestData", "TestContainer", "TestTable") self.assertTupleEqual(types, expected) # test that the namespace is correct and the types_key is updated for test ns - test_namespace = self.ns_catalog.get_namespace('test') - expected = {'doc': 'Test namespace', - 'schema': [{'namespace': 'hdmf-common', - 'my_data_types': ['Data', 'DynamicTable', 'Container']}, - {'doc': 'This source module contains base data types.', - 'source': 'test.base.yaml', - 'title': 'Base data types'}], - 'name': 'test', - 'full_name': 'Test', - 'version': '0.1.0', - 'author': ['Test test'], - 'contact': ['test@test.com']} + test_namespace = self.ns_catalog.get_namespace("test") + expected = { + "doc": "Test 
namespace", + "schema": [ + { + "namespace": "hdmf-common", + "my_data_types": ["Data", "DynamicTable", "Container"], + }, + { + "doc": "This source module contains base data types.", + "source": "test.base.yaml", + "title": "Base data types", + }, + ], + "name": "test", + "full_name": "Test", + "version": "0.1.0", + "author": ["Test test"], + "contact": ["test@test.com"], + } self.assertDictEqual(test_namespace, expected) # test that the def_key is updated for test ns - test_data_spec = self.ns_catalog.get_spec('test', 'TestData') - self.assertTrue('my_data_type_def' in test_data_spec) - self.assertTrue('my_data_type_inc' in test_data_spec) + test_data_spec = self.ns_catalog.get_spec("test", "TestData") + self.assertTrue("my_data_type_def" in test_data_spec) + self.assertTrue("my_data_type_inc" in test_data_spec) # test that the def_key is maintained for hdmf-common - data_spec = self.ns_catalog.get_spec('hdmf-common', 'Data') - self.assertTrue('data_type_def' in data_spec) + data_spec = self.ns_catalog.get_spec("hdmf-common", "Data") + self.assertTrue("data_type_def" in data_spec) def test_load_namespaces_ext(self): - namespace_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test.namespace.yaml') + namespace_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "test.namespace.yaml") self.ns_catalog.load_namespaces(namespace_path) - ext_namespace_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test-ext.namespace.yaml') + ext_namespace_path = os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "test-ext.namespace.yaml", + ) ext_namespace_deps = self.ns_catalog.load_namespaces(ext_namespace_path) # test that the dependencies are correct, including dependencies of the dependencies - expected_deps = set(['TestData', 'TestContainer', 'TestTable', 'Container', 'Data', 'DynamicTable', - 'ElementIdentifiers', 'VectorData']) - self.assertSetEqual(set(ext_namespace_deps['test-ext']['test']), expected_deps) + expected_deps = set( + [ + "TestData", + "TestContainer", + "TestTable", + "Container", + "Data", + "DynamicTable", + "ElementIdentifiers", + "VectorData", + ] + ) + self.assertSetEqual(set(ext_namespace_deps["test-ext"]["test"]), expected_deps) def test_load_namespaces_bad_path(self): - namespace_path = 'test.namespace.yaml' + namespace_path = "test.namespace.yaml" msg = "namespace file 'test.namespace.yaml' not found" with self.assertRaisesWith(IOError, msg): self.ns_catalog.load_namespaces(namespace_path) def test_load_namespaces_twice(self): - namespace_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test.namespace.yaml') + namespace_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "test.namespace.yaml") namespace_deps1 = self.ns_catalog.load_namespaces(namespace_path) namespace_deps2 = self.ns_catalog.load_namespaces(namespace_path) self.assertDictEqual(namespace_deps1, namespace_deps2) diff --git a/tests/unit/spec_tests/test_ref_spec.py b/tests/unit/spec_tests/test_ref_spec.py index bb1c0efb8..9cfc99467 100644 --- a/tests/unit/spec_tests/test_ref_spec.py +++ b/tests/unit/spec_tests/test_ref_spec.py @@ -5,19 +5,18 @@ class RefSpecTests(TestCase): - def test_constructor(self): - spec = RefSpec('TimeSeries', 'object') - self.assertEqual(spec.target_type, 'TimeSeries') - self.assertEqual(spec.reftype, 'object') + spec = RefSpec("TimeSeries", "object") + self.assertEqual(spec.target_type, "TimeSeries") + self.assertEqual(spec.reftype, "object") json.dumps(spec) # to ensure there are no circular links 
def test_wrong_reference_type(self): with self.assertRaises(ValueError): - RefSpec('TimeSeries', 'unknownreftype') + RefSpec("TimeSeries", "unknownreftype") def test_isregion(self): - spec = RefSpec('TimeSeries', 'object') + spec = RefSpec("TimeSeries", "object") self.assertFalse(spec.is_region()) - spec = RefSpec('Data', 'region') + spec = RefSpec("Data", "region") self.assertTrue(spec.is_region()) diff --git a/tests/unit/spec_tests/test_spec_catalog.py b/tests/unit/spec_tests/test_spec_catalog.py index d12f352e8..675528940 100644 --- a/tests/unit/spec_tests/test_spec_catalog.py +++ b/tests/unit/spec_tests/test_spec_catalog.py @@ -1,122 +1,138 @@ import copy -from hdmf.spec import GroupSpec, DatasetSpec, AttributeSpec, SpecCatalog +from hdmf.spec import AttributeSpec, DatasetSpec, GroupSpec, SpecCatalog from hdmf.testing import TestCase class SpecCatalogTest(TestCase): - def setUp(self): self.catalog = SpecCatalog() self.attributes = [ - AttributeSpec('attribute1', 'my first attribute', 'text'), - AttributeSpec('attribute2', 'my second attribute', 'text') + AttributeSpec("attribute1", "my first attribute", "text"), + AttributeSpec("attribute2", "my second attribute", "text"), ] - self.spec = DatasetSpec('my first dataset', - 'int', - name='dataset1', - dims=(None, None), - attributes=self.attributes, - linkable=False, - data_type_def='EphysData') + self.spec = DatasetSpec( + "my first dataset", + "int", + name="dataset1", + dims=(None, None), + attributes=self.attributes, + linkable=False, + data_type_def="EphysData", + ) def test_register_spec(self): - self.catalog.register_spec(self.spec, 'test.yaml') - result = self.catalog.get_spec('EphysData') + self.catalog.register_spec(self.spec, "test.yaml") + result = self.catalog.get_spec("EphysData") self.assertIs(result, self.spec) def test_hierarchy(self): - spikes_spec = DatasetSpec('my extending dataset', 'int', - data_type_inc='EphysData', - data_type_def='SpikeData') + spikes_spec = DatasetSpec( + "my extending dataset", + "int", + data_type_inc="EphysData", + data_type_def="SpikeData", + ) - lfp_spec = DatasetSpec('my second extending dataset', 'int', - data_type_inc='EphysData', - data_type_def='LFPData') + lfp_spec = DatasetSpec( + "my second extending dataset", + "int", + data_type_inc="EphysData", + data_type_def="LFPData", + ) - self.catalog.register_spec(self.spec, 'test.yaml') - self.catalog.register_spec(spikes_spec, 'test.yaml') - self.catalog.register_spec(lfp_spec, 'test.yaml') + self.catalog.register_spec(self.spec, "test.yaml") + self.catalog.register_spec(spikes_spec, "test.yaml") + self.catalog.register_spec(lfp_spec, "test.yaml") - spike_hierarchy = self.catalog.get_hierarchy('SpikeData') - lfp_hierarchy = self.catalog.get_hierarchy('LFPData') - ephys_hierarchy = self.catalog.get_hierarchy('EphysData') - self.assertTupleEqual(spike_hierarchy, ('SpikeData', 'EphysData')) - self.assertTupleEqual(lfp_hierarchy, ('LFPData', 'EphysData')) - self.assertTupleEqual(ephys_hierarchy, ('EphysData',)) + spike_hierarchy = self.catalog.get_hierarchy("SpikeData") + lfp_hierarchy = self.catalog.get_hierarchy("LFPData") + ephys_hierarchy = self.catalog.get_hierarchy("EphysData") + self.assertTupleEqual(spike_hierarchy, ("SpikeData", "EphysData")) + self.assertTupleEqual(lfp_hierarchy, ("LFPData", "EphysData")) + self.assertTupleEqual(ephys_hierarchy, ("EphysData",)) def test_subtypes(self): """ - -BaseContainer--+-->AContainer--->ADContainer - | - +-->BContainer + -BaseContainer--+-->AContainer--->ADContainer + | + +-->BContainer """ - 
base_spec = GroupSpec(doc='Base container', - data_type_def='BaseContainer') - acontainer = GroupSpec(doc='AContainer', - data_type_inc='BaseContainer', - data_type_def='AContainer') - adcontainer = GroupSpec(doc='ADContainer', - data_type_inc='AContainer', - data_type_def='ADContainer') - bcontainer = GroupSpec(doc='BContainer', - data_type_inc='BaseContainer', - data_type_def='BContainer') - self.catalog.register_spec(base_spec, 'test.yaml') - self.catalog.register_spec(acontainer, 'test.yaml') - self.catalog.register_spec(adcontainer, 'test.yaml') - self.catalog.register_spec(bcontainer, 'test.yaml') - base_spec_subtypes = self.catalog.get_subtypes('BaseContainer') + base_spec = GroupSpec(doc="Base container", data_type_def="BaseContainer") + acontainer = GroupSpec( + doc="AContainer", + data_type_inc="BaseContainer", + data_type_def="AContainer", + ) + adcontainer = GroupSpec( + doc="ADContainer", + data_type_inc="AContainer", + data_type_def="ADContainer", + ) + bcontainer = GroupSpec( + doc="BContainer", + data_type_inc="BaseContainer", + data_type_def="BContainer", + ) + self.catalog.register_spec(base_spec, "test.yaml") + self.catalog.register_spec(acontainer, "test.yaml") + self.catalog.register_spec(adcontainer, "test.yaml") + self.catalog.register_spec(bcontainer, "test.yaml") + base_spec_subtypes = self.catalog.get_subtypes("BaseContainer") base_spec_subtypes = tuple(sorted(base_spec_subtypes)) # Sort so we have a guaranteed order for comparison - acontainer_subtypes = self.catalog.get_subtypes('AContainer') - bcontainer_substypes = self.catalog.get_subtypes('BContainer') - adcontainer_subtypes = self.catalog.get_subtypes('ADContainer') + acontainer_subtypes = self.catalog.get_subtypes("AContainer") + bcontainer_substypes = self.catalog.get_subtypes("BContainer") + adcontainer_subtypes = self.catalog.get_subtypes("ADContainer") self.assertTupleEqual(adcontainer_subtypes, ()) self.assertTupleEqual(bcontainer_substypes, ()) - self.assertTupleEqual(acontainer_subtypes, ('ADContainer',)) - self.assertTupleEqual(base_spec_subtypes, ('AContainer', 'ADContainer', 'BContainer')) + self.assertTupleEqual(acontainer_subtypes, ("ADContainer",)) + self.assertTupleEqual(base_spec_subtypes, ("AContainer", "ADContainer", "BContainer")) def test_subtypes_norecursion(self): """ - -BaseContainer--+-->AContainer--->ADContainer - | - +-->BContainer + -BaseContainer--+-->AContainer--->ADContainer + | + +-->BContainer """ - base_spec = GroupSpec(doc='Base container', - data_type_def='BaseContainer') - acontainer = GroupSpec(doc='AContainer', - data_type_inc='BaseContainer', - data_type_def='AContainer') - adcontainer = GroupSpec(doc='ADContainer', - data_type_inc='AContainer', - data_type_def='ADContainer') - bcontainer = GroupSpec(doc='BContainer', - data_type_inc='BaseContainer', - data_type_def='BContainer') - self.catalog.register_spec(base_spec, 'test.yaml') - self.catalog.register_spec(acontainer, 'test.yaml') - self.catalog.register_spec(adcontainer, 'test.yaml') - self.catalog.register_spec(bcontainer, 'test.yaml') - base_spec_subtypes = self.catalog.get_subtypes('BaseContainer', recursive=False) + base_spec = GroupSpec(doc="Base container", data_type_def="BaseContainer") + acontainer = GroupSpec( + doc="AContainer", + data_type_inc="BaseContainer", + data_type_def="AContainer", + ) + adcontainer = GroupSpec( + doc="ADContainer", + data_type_inc="AContainer", + data_type_def="ADContainer", + ) + bcontainer = GroupSpec( + doc="BContainer", + data_type_inc="BaseContainer", + 
data_type_def="BContainer", + ) + self.catalog.register_spec(base_spec, "test.yaml") + self.catalog.register_spec(acontainer, "test.yaml") + self.catalog.register_spec(adcontainer, "test.yaml") + self.catalog.register_spec(bcontainer, "test.yaml") + base_spec_subtypes = self.catalog.get_subtypes("BaseContainer", recursive=False) base_spec_subtypes = tuple(sorted(base_spec_subtypes)) # Sort so we have a guaranteed order for comparison - acontainer_subtypes = self.catalog.get_subtypes('AContainer', recursive=False) - bcontainer_substypes = self.catalog.get_subtypes('BContainer', recursive=False) - adcontainer_subtypes = self.catalog.get_subtypes('ADContainer', recursive=False) + acontainer_subtypes = self.catalog.get_subtypes("AContainer", recursive=False) + bcontainer_substypes = self.catalog.get_subtypes("BContainer", recursive=False) + adcontainer_subtypes = self.catalog.get_subtypes("ADContainer", recursive=False) self.assertTupleEqual(adcontainer_subtypes, ()) self.assertTupleEqual(bcontainer_substypes, ()) - self.assertTupleEqual(acontainer_subtypes, ('ADContainer',)) - self.assertTupleEqual(base_spec_subtypes, ('AContainer', 'BContainer')) + self.assertTupleEqual(acontainer_subtypes, ("ADContainer",)) + self.assertTupleEqual(base_spec_subtypes, ("AContainer", "BContainer")) def test_subtypes_unknown_type(self): - subtypes_of_bad_type = self.catalog.get_subtypes('UnknownType') + subtypes_of_bad_type = self.catalog.get_subtypes("UnknownType") self.assertTupleEqual(subtypes_of_bad_type, ()) def test_get_spec_source_file(self): - spikes_spec = GroupSpec('test group', - data_type_def='SpikeData') - source_file_path = '/test/myt/test.yaml' + spikes_spec = GroupSpec("test group", data_type_def="SpikeData") + source_file_path = "/test/myt/test.yaml" self.catalog.auto_register(spikes_spec, source_file_path) - recorded_source_file_path = self.catalog.get_spec_source_file('SpikeData') + recorded_source_file_path = self.catalog.get_spec_source_file("SpikeData") self.assertEqual(recorded_source_file_path, source_file_path) def test_get_full_hierarchy(self): @@ -135,78 +151,79 @@ def test_get_full_hierarchy(self): >> "BContainer": {} >> } """ - base_spec = GroupSpec(doc='Base container', - data_type_def='BaseContainer') - acontainer = GroupSpec(doc='AContainer', - data_type_inc='BaseContainer', - data_type_def='AContainer') - adcontainer = GroupSpec(doc='ADContainer', - data_type_inc='AContainer', - data_type_def='ADContainer') - bcontainer = GroupSpec(doc='BContainer', - data_type_inc='BaseContainer', - data_type_def='BContainer') - self.catalog.register_spec(base_spec, 'test.yaml') - self.catalog.register_spec(acontainer, 'test.yaml') - self.catalog.register_spec(adcontainer, 'test.yaml') - self.catalog.register_spec(bcontainer, 'test.yaml') + base_spec = GroupSpec(doc="Base container", data_type_def="BaseContainer") + acontainer = GroupSpec( + doc="AContainer", + data_type_inc="BaseContainer", + data_type_def="AContainer", + ) + adcontainer = GroupSpec( + doc="ADContainer", + data_type_inc="AContainer", + data_type_def="ADContainer", + ) + bcontainer = GroupSpec( + doc="BContainer", + data_type_inc="BaseContainer", + data_type_def="BContainer", + ) + self.catalog.register_spec(base_spec, "test.yaml") + self.catalog.register_spec(acontainer, "test.yaml") + self.catalog.register_spec(adcontainer, "test.yaml") + self.catalog.register_spec(bcontainer, "test.yaml") full_hierarchy = self.catalog.get_full_hierarchy() expected_hierarchy = { - "BaseContainer": { - "AContainer": { - "ADContainer": {} - }, 
- "BContainer": {} - } - } + "BaseContainer": { + "AContainer": {"ADContainer": {}}, + "BContainer": {}, + } + } self.assertDictEqual(full_hierarchy, expected_hierarchy) def test_copy_spec_catalog(self): # Register the spec first - self.catalog.register_spec(self.spec, 'test.yaml') - result = self.catalog.get_spec('EphysData') + self.catalog.register_spec(self.spec, "test.yaml") + result = self.catalog.get_spec("EphysData") self.assertIs(result, self.spec) # Now test the copy re = copy.copy(self.catalog) - self.assertTupleEqual(self.catalog.get_registered_types(), - re.get_registered_types()) + self.assertTupleEqual(self.catalog.get_registered_types(), re.get_registered_types()) def test_deepcopy_spec_catalog(self): # Register the spec first - self.catalog.register_spec(self.spec, 'test.yaml') - result = self.catalog.get_spec('EphysData') + self.catalog.register_spec(self.spec, "test.yaml") + result = self.catalog.get_spec("EphysData") self.assertIs(result, self.spec) # Now test the copy re = copy.deepcopy(self.catalog) - self.assertTupleEqual(self.catalog.get_registered_types(), - re.get_registered_types()) + self.assertTupleEqual(self.catalog.get_registered_types(), re.get_registered_types()) def test_catch_duplicate_spec_nested(self): spec1 = GroupSpec( - data_type_def='Group1', - doc='This is my new group 1', + data_type_def="Group1", + doc="This is my new group 1", ) spec2 = GroupSpec( - data_type_def='Group2', - doc='This is my new group 2', + data_type_def="Group2", + doc="This is my new group 2", groups=[spec1], # nested definition ) - source = 'test_extension.yaml' + source = "test_extension.yaml" self.catalog.register_spec(spec1, source) self.catalog.register_spec(spec2, source) # this is OK because Group1 is the same spec - ret = self.catalog.get_spec('Group1') + ret = self.catalog.get_spec("Group1") self.assertIs(ret, spec1) def test_catch_duplicate_spec_different(self): spec1 = GroupSpec( - data_type_def='Group1', - doc='This is my new group 1', + data_type_def="Group1", + doc="This is my new group 1", ) spec2 = GroupSpec( - data_type_def='Group1', - doc='This is my other group 1', + data_type_def="Group1", + doc="This is my other group 1", ) - source = 'test_extension.yaml' + source = "test_extension.yaml" self.catalog.register_spec(spec1, source) msg = "'Group1' - cannot overwrite existing specification" with self.assertRaisesWith(ValueError, msg): @@ -214,15 +231,15 @@ def test_catch_duplicate_spec_different(self): def test_catch_duplicate_spec_different_source(self): spec1 = GroupSpec( - data_type_def='Group1', - doc='This is my new group 1', + data_type_def="Group1", + doc="This is my new group 1", ) spec2 = GroupSpec( - data_type_def='Group1', - doc='This is my new group 1', + data_type_def="Group1", + doc="This is my new group 1", ) - source1 = 'test_extension1.yaml' - source2 = 'test_extension2.yaml' + source1 = "test_extension1.yaml" + source2 = "test_extension2.yaml" self.catalog.register_spec(spec1, source1) msg = "'Group1' - cannot overwrite existing specification" with self.assertRaisesWith(ValueError, msg): diff --git a/tests/unit/spec_tests/test_spec_write.py b/tests/unit/spec_tests/test_spec_write.py index a9410df2a..0211f4ef4 100644 --- a/tests/unit/spec_tests/test_spec_write.py +++ b/tests/unit/spec_tests/test_spec_write.py @@ -1,55 +1,58 @@ import datetime import os -from hdmf.spec.namespace import SpecNamespace, NamespaceCatalog +from hdmf.spec.namespace import NamespaceCatalog, SpecNamespace from hdmf.spec.spec import GroupSpec from hdmf.spec.write import 
NamespaceBuilder, YAMLSpecWriter, export_spec from hdmf.testing import TestCase class TestSpec(TestCase): - def setUp(self): # create a builder for the namespace self.ns_name = "mylab" self.date = datetime.datetime.now() - self.ns_builder = NamespaceBuilder(doc="mydoc", - name=self.ns_name, - full_name="My Laboratory", - version="0.0.1", - author="foo", - contact="foo@bar.com", - namespace_cls=SpecNamespace, - date=self.date) + self.ns_builder = NamespaceBuilder( + doc="mydoc", + name=self.ns_name, + full_name="My Laboratory", + version="0.0.1", + author="foo", + contact="foo@bar.com", + namespace_cls=SpecNamespace, + date=self.date, + ) # create extensions - ext1 = GroupSpec('A custom DataSeries interface', - attributes=[], - datasets=[], - groups=[], - data_type_inc=None, - data_type_def='MyDataSeries') - - ext2 = GroupSpec('An extension of a DataSeries interface', - attributes=[], - datasets=[], - groups=[], - data_type_inc='MyDataSeries', - data_type_def='MyExtendedMyDataSeries') - - ext2.add_dataset(doc='test', - dtype='float', - name='testdata') + ext1 = GroupSpec( + "A custom DataSeries interface", + attributes=[], + datasets=[], + groups=[], + data_type_inc=None, + data_type_def="MyDataSeries", + ) + + ext2 = GroupSpec( + "An extension of a DataSeries interface", + attributes=[], + datasets=[], + groups=[], + data_type_inc="MyDataSeries", + data_type_def="MyExtendedMyDataSeries", + ) + + ext2.add_dataset(doc="test", dtype="float", name="testdata") self.data_types = [ext1, ext2] # add the extension - self.ext_source_path = 'mylab.extensions.yaml' - self.namespace_path = 'mylab.namespace.yaml' + self.ext_source_path = "mylab.extensions.yaml" + self.namespace_path = "mylab.namespace.yaml" def _test_extensions_file(self): - with open(self.ext_source_path, 'r') as file: + with open(self.ext_source_path, "r") as file: match_str = """groups: - data_type_def: MyDataSeries doc: A custom DataSeries interface @@ -65,7 +68,7 @@ def _test_extensions_file(self): self.assertEqual(nsstr, match_str) def _test_namespace_file(self): - with open(self.namespace_path, 'r') as file: + with open(self.namespace_path, "r") as file: match_str = """namespaces: - author: foo contact: foo@bar.com @@ -78,21 +81,23 @@ def _test_namespace_file(self): source: mylab.extensions.yaml title: Extensions for my lab version: 0.0.1 -""" % self.date.isoformat() # noqa: E122 +""" % self.date.isoformat() nsstr = file.read() self.assertEqual(nsstr, match_str) class TestNamespaceBuilder(TestSpec): - NS_NAME = 'test_ns' + NS_NAME = "test_ns" def setUp(self): super().setUp() for data_type in self.data_types: self.ns_builder.add_spec(source=self.ext_source_path, spec=data_type) - self.ns_builder.add_source(source=self.ext_source_path, - doc='Extensions for my lab', - title='My lab extensions') + self.ns_builder.add_source( + source=self.ext_source_path, + doc="Extensions for my lab", + title="My lab extensions", + ) self.ns_builder.export(self.namespace_path) def tearDown(self): @@ -115,48 +120,61 @@ def test_read_namespace(self): self.assertEqual(loaded_ns.full_name, "My Laboratory") self.assertEqual(loaded_ns.name, "mylab") self.assertEqual(loaded_ns.date, self.date.isoformat()) - self.assertDictEqual(loaded_ns.schema[0], {'doc': 'Extensions for my lab', - 'source': 'mylab.extensions.yaml', - 'title': 'Extensions for my lab'}) + self.assertDictEqual( + loaded_ns.schema[0], + { + "doc": "Extensions for my lab", + "source": "mylab.extensions.yaml", + "title": "Extensions for my lab", + }, + ) self.assertEqual(loaded_ns.version, 
"0.0.1") def test_get_source_files(self): ns_catalog = NamespaceCatalog() ns_catalog.load_namespaces(self.namespace_path, resolve=True) loaded_ns = ns_catalog.get_namespace(self.ns_name) - self.assertListEqual(loaded_ns.get_source_files(), ['mylab.extensions.yaml']) + self.assertListEqual(loaded_ns.get_source_files(), ["mylab.extensions.yaml"]) def test_get_source_description(self): ns_catalog = NamespaceCatalog() ns_catalog.load_namespaces(self.namespace_path, resolve=True) loaded_ns = ns_catalog.get_namespace(self.ns_name) - descr = loaded_ns.get_source_description('mylab.extensions.yaml') - self.assertDictEqual(descr, {'doc': 'Extensions for my lab', - 'source': 'mylab.extensions.yaml', - 'title': 'Extensions for my lab'}) + descr = loaded_ns.get_source_description("mylab.extensions.yaml") + self.assertDictEqual( + descr, + { + "doc": "Extensions for my lab", + "source": "mylab.extensions.yaml", + "title": "Extensions for my lab", + }, + ) def test_missing_version(self): """Test that creating a namespace builder without a version raises an error.""" msg = "Namespace '%s' missing key 'version'. Please specify a version for the extension." % self.ns_name with self.assertRaisesWith(ValueError, msg): - self.ns_builder = NamespaceBuilder(doc="mydoc", - name=self.ns_name, - full_name="My Laboratory", - author="foo", - contact="foo@bar.com", - namespace_cls=SpecNamespace, - date=self.date) + self.ns_builder = NamespaceBuilder( + doc="mydoc", + name=self.ns_name, + full_name="My Laboratory", + author="foo", + contact="foo@bar.com", + namespace_cls=SpecNamespace, + date=self.date, + ) class TestYAMLSpecWrite(TestSpec): - def setUp(self): super().setUp() for data_type in self.data_types: self.ns_builder.add_spec(source=self.ext_source_path, spec=data_type) - self.ns_builder.add_source(source=self.ext_source_path, - doc='Extensions for my lab', - title='My lab extensions') + self.ns_builder.add_source( + source=self.ext_source_path, + doc="Extensions for my lab", + title="My lab extensions", + ) def tearDown(self): if os.path.exists(self.ext_source_path): @@ -165,8 +183,8 @@ def tearDown(self): os.remove(self.namespace_path) def test_init(self): - temp = YAMLSpecWriter('.') - self.assertEqual(temp._YAMLSpecWriter__outdir, '.') + temp = YAMLSpecWriter(".") + self.assertEqual(temp._YAMLSpecWriter__outdir, ".") def test_write_namespace(self): temp = YAMLSpecWriter() @@ -179,10 +197,9 @@ def test_get_name(self): class TestExportSpec(TestSpec): - def test_export(self): """Test that export_spec writes the correct files.""" - export_spec(self.ns_builder, self.data_types, '.') + export_spec(self.ns_builder, self.data_types, ".") self._test_namespace_file() self._test_extensions_file() @@ -193,7 +210,7 @@ def tearDown(self): os.remove(self.namespace_path) def _test_namespace_file(self): - with open(self.namespace_path, 'r') as file: + with open(self.namespace_path, "r") as file: match_str = """namespaces: - author: foo contact: foo@bar.com @@ -204,11 +221,11 @@ def _test_namespace_file(self): schema: - source: mylab.extensions.yaml version: 0.0.1 -""" % self.date.isoformat() # noqa: E122 +""" % self.date.isoformat() nsstr = file.read() self.assertEqual(nsstr, match_str) def test_missing_data_types(self): """Test that calling export_spec on a namespace builder without data types raises a warning.""" - with self.assertWarnsWith(UserWarning, 'No data types specified. Exiting.'): - export_spec(self.ns_builder, [], '.') + with self.assertWarnsWith(UserWarning, "No data types specified. 
Exiting."): + export_spec(self.ns_builder, [], ".") diff --git a/tests/unit/test_container.py b/tests/unit/test_container.py index d0426c85a..cd2ddc021 100644 --- a/tests/unit/test_container.py +++ b/tests/unit/test_container.py @@ -1,8 +1,9 @@ +from uuid import UUID, uuid4 + import numpy as np -from uuid import uuid4, UUID -from hdmf.container import AbstractContainer, Container, Data, ExternalResourcesManager from hdmf.common.resources import ExternalResources +from hdmf.container import AbstractContainer, Container, Data, ExternalResourcesManager from hdmf.testing import TestCase from hdmf.utils import docval @@ -22,14 +23,16 @@ def test_link_and_get_resources(self): class TestContainer(TestCase): - def test_new(self): - """Test that __new__ properly sets parent and other fields. - """ - parent_obj = Container('obj1') + """Test that __new__ properly sets parent and other fields.""" + parent_obj = Container("obj1") child_object_id = str(uuid4()) - child_obj = Container.__new__(Container, parent=parent_obj, object_id=child_object_id, - container_source="test_source") + child_obj = Container.__new__( + Container, + parent=parent_obj, + object_id=child_object_id, + container_source="test_source", + ) self.assertIs(child_obj.parent, parent_obj) self.assertIs(parent_obj.children[0], child_obj) self.assertEqual(child_obj.object_id, child_object_id) @@ -37,77 +40,74 @@ def test_new(self): self.assertTrue(child_obj.modified) def test_new_object_id_none(self): - """Test that passing object_id=None to __new__ is OK and results in a non-None object ID being assigned. - """ - parent_obj = Container('obj1') + """Test that passing object_id=None to __new__ is OK and results in a non-None object ID being assigned.""" + parent_obj = Container("obj1") child_obj = Container.__new__(Container, parent=parent_obj, object_id=None) self.assertIsNotNone(child_obj.object_id) UUID(child_obj.object_id, version=4) # raises ValueError if invalid def test_new_construct_mode(self): - """Test that passing in_construct_mode to __new__ sets _in_construct_mode and _in_construct_mode can be reset. - """ - parent_obj = Container('obj1') + """Test that passing in_construct_mode to __new__ sets _in_construct_mode and that can be reset.""" + parent_obj = Container("obj1") child_obj = Container.__new__(Container, parent=parent_obj, object_id=None, in_construct_mode=True) self.assertTrue(child_obj._in_construct_mode) child_obj._in_construct_mode = False self.assertFalse(child_obj._in_construct_mode) def test_init(self): - """Test that __init__ properly sets object ID and other fields. 
- """ - obj = Container('obj1') + """Test that __init__ properly sets object ID and other fields.""" + obj = Container("obj1") self.assertIsNotNone(obj.object_id) UUID(obj.object_id, version=4) # raises ValueError if invalid self.assertFalse(obj._in_construct_mode) self.assertTrue(obj.modified) self.assertEqual(obj.children, tuple()) self.assertIsNone(obj.parent) - self.assertEqual(obj.name, 'obj1') + self.assertEqual(obj.name, "obj1") def test_set_parent(self): - """Test that parent setter properly sets parent - """ - parent_obj = Container('obj1') - child_obj = Container('obj2') + """Test that parent setter properly sets parent""" + parent_obj = Container("obj1") + child_obj = Container("obj2") child_obj.parent = parent_obj self.assertIs(child_obj.parent, parent_obj) self.assertIs(parent_obj.children[0], child_obj) def test_set_parent_overwrite(self): - """Test that parent setter properly blocks overwriting - """ - parent_obj = Container('obj1') - child_obj = Container('obj2') + """Test that parent setter properly blocks overwriting""" + parent_obj = Container("obj1") + child_obj = Container("obj2") child_obj.parent = parent_obj self.assertIs(parent_obj.children[0], child_obj) - another_obj = Container('obj3') - with self.assertRaisesWith(ValueError, - 'Cannot reassign parent to Container: %s. Parent is already: %s.' - % (repr(child_obj), repr(child_obj.parent))): + another_obj = Container("obj3") + with self.assertRaisesWith( + ValueError, + "Cannot reassign parent to Container: %s. Parent is already: %s." + % (repr(child_obj), repr(child_obj.parent)), + ): child_obj.parent = another_obj self.assertIs(child_obj.parent, parent_obj) self.assertIs(parent_obj.children[0], child_obj) def test_set_parent_overwrite_proxy(self): - """Test that parent setter properly blocks overwriting with proxy/object - """ - child_obj = Container('obj2') + """Test that parent setter properly blocks overwriting with proxy/object""" + child_obj = Container("obj2") child_obj.parent = object() - with self.assertRaisesRegex(ValueError, - r"Got None for parent of '[^/]+' - cannot overwrite Proxy with NoneType"): + with self.assertRaisesRegex( + ValueError, + r"Got None for parent of '[^/]+' - cannot overwrite Proxy with NoneType", + ): child_obj.parent = None def test_slash_restriction(self): - self.assertRaises(ValueError, Container, 'bad/name') + self.assertRaises(ValueError, Container, "bad/name") def test_set_modified_parent(self): - """Test that set modified properly sets parent modified - """ - parent_obj = Container('obj1') - child_obj = Container('obj2') + """Test that set modified properly sets parent modified""" + parent_obj = Container("obj1") + child_obj = Container("obj2") child_obj.parent = parent_obj parent_obj.set_modified(False) child_obj.set_modified(False) @@ -116,23 +116,24 @@ def test_set_modified_parent(self): self.assertTrue(child_obj.parent.modified) def test_add_child(self): - """Test that add child creates deprecation warning and also properly sets child's parent and modified - """ - parent_obj = Container('obj1') - child_obj = Container('obj2') + """Test that add child creates deprecation warning and also properly sets child's parent and modified""" + parent_obj = Container("obj1") + child_obj = Container("obj2") parent_obj.set_modified(False) - with self.assertWarnsWith(DeprecationWarning, 'add_child is deprecated. Set the parent attribute instead.'): + with self.assertWarnsWith( + DeprecationWarning, + "add_child is deprecated. 
Set the parent attribute instead.", + ): parent_obj.add_child(child_obj) self.assertIs(child_obj.parent, parent_obj) self.assertTrue(parent_obj.modified) self.assertIs(parent_obj.children[0], child_obj) def test_set_parent_exists(self): - """Test that setting a parent a second time does nothing - """ - parent_obj = Container('obj1') - child_obj = Container('obj2') - child_obj3 = Container('obj3') + """Test that setting a parent a second time does nothing""" + parent_obj = Container("obj1") + child_obj = Container("obj2") + child_obj3 = Container("obj3") child_obj.parent = parent_obj child_obj.parent = parent_obj child_obj3.parent = parent_obj @@ -141,25 +142,27 @@ def test_set_parent_exists(self): self.assertIs(parent_obj.children[1], child_obj3) def test_reassign_container_source(self): - """Test that reassign container source throws error - """ - parent_obj = Container('obj1') - parent_obj.container_source = 'a source' - with self.assertRaisesWith(Exception, 'cannot reassign container_source'): - parent_obj.container_source = 'some other source' + """Test that reassign container source throws error""" + parent_obj = Container("obj1") + parent_obj.container_source = "a source" + with self.assertRaisesWith(Exception, "cannot reassign container_source"): + parent_obj.container_source = "some other source" def test_repr(self): - parent_obj = Container('obj1') + parent_obj = Container("obj1") self.assertRegex(str(parent_obj), r"obj1 hdmf.container.Container at 0x\d+") def test_type_hierarchy(self): self.assertEqual(Container.type_hierarchy(), (Container, AbstractContainer, object)) - self.assertEqual(Subcontainer.type_hierarchy(), (Subcontainer, Container, AbstractContainer, object)) + self.assertEqual( + Subcontainer.type_hierarchy(), + (Subcontainer, Container, AbstractContainer, object), + ) def test_generate_new_id_parent(self): """Test that generate_new_id sets a new ID on the container and its children and sets modified on all.""" - parent_obj = Container('obj1') - child_obj = Container('obj2') + parent_obj = Container("obj1") + child_obj = Container("obj2") child_obj.parent = parent_obj old_parent_id = parent_obj.object_id old_child_id = child_obj.object_id @@ -174,8 +177,8 @@ def test_generate_new_id_parent(self): def test_generate_new_id_child(self): """Test that generate_new_id sets a new ID on the container and not its parent and sets modified on both.""" - parent_obj = Container('obj1') - child_obj = Container('obj2') + parent_obj = Container("obj1") + child_obj = Container("obj2") child_obj.parent = parent_obj old_parent_id = parent_obj.object_id old_child_id = child_obj.object_id @@ -190,8 +193,8 @@ def test_generate_new_id_child(self): def test_generate_new_id_parent_no_recurse(self): """Test that generate_new_id(recurse=False) sets a new ID on the container and not its children.""" - parent_obj = Container('obj1') - child_obj = Container('obj2') + parent_obj = Container("obj1") + child_obj = Container("obj2") child_obj.parent = parent_obj old_parent_id = parent_obj.object_id old_child_id = child_obj.object_id @@ -205,118 +208,110 @@ def test_generate_new_id_parent_no_recurse(self): self.assertFalse(child_obj.modified) def test_remove_child(self): - """Test that removing a child removes only the child. 
- """ - parent_obj = Container('obj1') - child_obj = Container('obj2') - child_obj3 = Container('obj3') + """Test that removing a child removes only the child.""" + parent_obj = Container("obj1") + child_obj = Container("obj2") + child_obj3 = Container("obj3") child_obj.parent = parent_obj child_obj3.parent = parent_obj parent_obj._remove_child(child_obj) self.assertIsNone(child_obj.parent) - self.assertTupleEqual(parent_obj.children, (child_obj3, )) + self.assertTupleEqual(parent_obj.children, (child_obj3,)) self.assertTrue(parent_obj.modified) self.assertTrue(child_obj.modified) def test_remove_child_noncontainer(self): - """Test that removing a non-Container child raises an error. - """ + """Test that removing a non-Container child raises an error.""" msg = "Cannot remove non-AbstractContainer object from children." with self.assertRaisesWith(ValueError, msg): - Container('obj1')._remove_child(object()) + Container("obj1")._remove_child(object()) def test_remove_child_nonchild(self): - """Test that removing a non-Container child raises an error. - """ + """Test that removing a non-Container child raises an error.""" msg = "Container 'dummy' is not a child of Container 'obj1'." with self.assertRaisesWith(ValueError, msg): - Container('obj1')._remove_child(Container('dummy')) + Container("obj1")._remove_child(Container("dummy")) def test_reset_parent(self): - """Test that removing a child removes only the child. - """ - parent_obj = Container('obj1') - child_obj = Container('obj2') - child_obj3 = Container('obj3') + """Test that removing a child removes only the child.""" + parent_obj = Container("obj1") + child_obj = Container("obj2") + child_obj3 = Container("obj3") child_obj.parent = parent_obj child_obj3.parent = parent_obj child_obj.reset_parent() self.assertIsNone(child_obj.parent) - self.assertTupleEqual(parent_obj.children, (child_obj3, )) + self.assertTupleEqual(parent_obj.children, (child_obj3,)) self.assertTrue(parent_obj.modified) self.assertTrue(child_obj.modified) def test_reset_parent_parent_noncontainer(self): - """Test that resetting a parent that is not a container raises an error. - """ - obj = Container('obj1') + """Test that resetting a parent that is not a container raises an error.""" + obj = Container("obj1") obj.parent = object() msg = "Cannot reset parent when parent is not an AbstractContainer: %s" % repr(obj.parent) with self.assertRaisesWith(ValueError, msg): obj.reset_parent() def test_reset_parent_no_parent(self): - """Test that resetting a non-existent parent has no effect. 
- """ - obj = Container('obj1') + """Test that resetting a non-existent parent has no effect.""" + obj = Container("obj1") obj.reset_parent() self.assertIsNone(obj.parent) class TestData(TestCase): - def test_constructor_scalar(self): - """Test that constructor works correctly on scalar data - """ - data_obj = Data('my_data', 'foobar') - self.assertEqual(data_obj.data, 'foobar') + """Test that constructor works correctly on scalar data""" + data_obj = Data("my_data", "foobar") + self.assertEqual(data_obj.data, "foobar") def test_bool_true(self): - """Test that __bool__ method works correctly on data with len - """ - data_obj = Data('my_data', [1, 2, 3, 4, 5]) + """Test that __bool__ method works correctly on data with len""" + data_obj = Data("my_data", [1, 2, 3, 4, 5]) self.assertTrue(data_obj) def test_bool_false(self): - """Test that __bool__ method works correctly on empty data - """ - data_obj = Data('my_data', []) + """Test that __bool__ method works correctly on empty data""" + data_obj = Data("my_data", []) self.assertFalse(data_obj) def test_shape_nparray(self): """ Test that shape works for np.array """ - data_obj = Data('my_data', np.arange(10).reshape(2, 5)) + data_obj = Data("my_data", np.arange(10).reshape(2, 5)) self.assertTupleEqual(data_obj.shape, (2, 5)) def test_shape_list(self): """ Test that shape works for np.array """ - data_obj = Data('my_data', [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]]) + data_obj = Data("my_data", [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]]) self.assertTupleEqual(data_obj.shape, (2, 5)) class TestAbstractContainerFieldsConf(TestCase): - def test_bad_fields_type(self): msg = "'__fields__' must be of type tuple" with self.assertRaisesWith(TypeError, msg): + class BadFieldsType(AbstractContainer): - __fields__ = {'name': 'field1'} + __fields__ = {"name": "field1"} def test_bad_field_conf_key(self): msg = "Unrecognized key 'child' in __fields__ config 'field1' on BadFieldConfKey" with self.assertRaisesWith(ValueError, msg): + class BadFieldConfKey(AbstractContainer): - __fields__ = ({'name': 'field1', 'child': True}, ) + __fields__ = ({"name": "field1", "child": True},) def test_bad_field_missing_name(self): msg = "must specify 'name' if using dict in __fields__" with self.assertRaisesWith(ValueError, msg): + class BadFieldConfKey(AbstractContainer): - __fields__ = ({'child': True}, ) + __fields__ = ({"child": True},) @staticmethod def find_all_properties(klass): @@ -331,46 +326,65 @@ class EmptyFields(AbstractContainer): self.assertTupleEqual(EmptyFields.get_fields_conf(), tuple()) props = TestAbstractContainerFieldsConf.find_all_properties(EmptyFields) - expected = ['children', 'container_source', 'fields', 'modified', 'name', 'object_id', 'parent'] + expected = [ + "children", + "container_source", + "fields", + "modified", + "name", + "object_id", + "parent", + ] self.assertListEqual(props, expected) def test_named_fields(self): class NamedFields(AbstractContainer): - __fields__ = ('field1', 'field2') + __fields__ = ("field1", "field2") - @docval({'name': 'field2', 'doc': 'field2 doc', 'type': str}) + @docval({"name": "field2", "doc": "field2 doc", "type": str}) def __init__(self, **kwargs): - super().__init__('test name') - self.field2 = kwargs['field2'] + super().__init__("test name") + self.field2 = kwargs["field2"] - self.assertTupleEqual(NamedFields.__fields__, ('field1', 'field2')) + self.assertTupleEqual(NamedFields.__fields__, ("field1", "field2")) self.assertIs(NamedFields._get_fields(), NamedFields.__fields__) - expected = ({'doc': None, 'name': 
'field1'}, - {'doc': 'field2 doc', 'name': 'field2'}) + expected = ( + {"doc": None, "name": "field1"}, + {"doc": "field2 doc", "name": "field2"}, + ) self.assertTupleEqual(NamedFields.get_fields_conf(), expected) props = TestAbstractContainerFieldsConf.find_all_properties(NamedFields) - expected = ['children', 'container_source', 'field1', 'field2', 'fields', 'modified', 'name', 'object_id', - 'parent'] + expected = [ + "children", + "container_source", + "field1", + "field2", + "fields", + "modified", + "name", + "object_id", + "parent", + ] self.assertListEqual(props, expected) - f1_doc = getattr(NamedFields, 'field1').__doc__ + f1_doc = getattr(NamedFields, "field1").__doc__ self.assertIsNone(f1_doc) - f2_doc = getattr(NamedFields, 'field2').__doc__ - self.assertEqual(f2_doc, 'field2 doc') + f2_doc = getattr(NamedFields, "field2").__doc__ + self.assertEqual(f2_doc, "field2 doc") - obj = NamedFields('field2 value') + obj = NamedFields("field2 value") self.assertIsNone(obj.field1) - self.assertEqual(obj.field2, 'field2 value') + self.assertEqual(obj.field2, "field2 value") - obj.field1 = 'field1 value' + obj.field1 = "field1 value" msg = "can't set attribute 'field2' -- already set" with self.assertRaisesWith(AttributeError, msg): - obj.field2 = 'field2 value' + obj.field2 = "field2 value" obj.field2 = None # None value does nothing - self.assertEqual(obj.field2, 'field2 value') + self.assertEqual(obj.field2, "field2 value") def test_with_doc(self): """Test that __fields__ related attributes are set correctly. @@ -378,24 +392,29 @@ def test_with_doc(self): Also test that the docstring for fields are not overridden by the docstring in the docval of __init__ if a doc is provided in cls.__fields__. """ + class NamedFieldsWithDoc(AbstractContainer): - __fields__ = ({'name': 'field1', 'doc': 'field1 orig doc'}, - {'name': 'field2', 'doc': 'field2 orig doc'}) + __fields__ = ( + {"name": "field1", "doc": "field1 orig doc"}, + {"name": "field2", "doc": "field2 orig doc"}, + ) - @docval({'name': 'field2', 'doc': 'field2 doc', 'type': str}) + @docval({"name": "field2", "doc": "field2 doc", "type": str}) def __init__(self, **kwargs): - super().__init__('test name') - self.field2 = kwargs['field2'] + super().__init__("test name") + self.field2 = kwargs["field2"] - expected = ({'doc': 'field1 orig doc', 'name': 'field1'}, - {'doc': 'field2 orig doc', 'name': 'field2'}) + expected = ( + {"doc": "field1 orig doc", "name": "field1"}, + {"doc": "field2 orig doc", "name": "field2"}, + ) self.assertTupleEqual(NamedFieldsWithDoc.get_fields_conf(), expected) - f1_doc = getattr(NamedFieldsWithDoc, 'field1').__doc__ - self.assertEqual(f1_doc, 'field1 orig doc') + f1_doc = getattr(NamedFieldsWithDoc, "field1").__doc__ + self.assertEqual(f1_doc, "field1 orig doc") - f2_doc = getattr(NamedFieldsWithDoc, 'field2').__doc__ - self.assertEqual(f2_doc, 'field2 orig doc') + f2_doc = getattr(NamedFieldsWithDoc, "field2").__doc__ + self.assertEqual(f2_doc, "field2 orig doc") def test_not_settable(self): """Test that __fields__ related attributes are set correctly. @@ -403,49 +422,74 @@ def test_not_settable(self): Also test that the docstring for fields are not overridden by the docstring in the docval of __init__ if a doc is provided in cls.__fields__. 
""" - class NamedFieldsNotSettable(AbstractContainer): - __fields__ = ({'name': 'field1', 'settable': True}, - {'name': 'field2', 'settable': False}) - expected = ({'doc': None, 'name': 'field1', 'settable': True}, - {'doc': None, 'name': 'field2', 'settable': False}) + class NamedFieldsNotSettable(AbstractContainer): + __fields__ = ( + {"name": "field1", "settable": True}, + {"name": "field2", "settable": False}, + ) + + expected = ( + {"doc": None, "name": "field1", "settable": True}, + {"doc": None, "name": "field2", "settable": False}, + ) self.assertTupleEqual(NamedFieldsNotSettable.get_fields_conf(), expected) - obj = NamedFieldsNotSettable('test name') + obj = NamedFieldsNotSettable("test name") - obj.field1 = 'field1 value' + obj.field1 = "field1 value" with self.assertRaises(AttributeError): - obj.field2 = 'field2 value' + obj.field2 = "field2 value" def test_inheritance(self): class NamedFields(AbstractContainer): - __fields__ = ({'name': 'field1', 'doc': 'field1 doc', 'settable': False}, ) + __fields__ = ({"name": "field1", "doc": "field1 doc", "settable": False},) class NamedFieldsChild(NamedFields): - __fields__ = ({'name': 'field2'}, ) + __fields__ = ({"name": "field2"},) - self.assertTupleEqual(NamedFieldsChild.__fields__, ('field1', 'field2')) + self.assertTupleEqual(NamedFieldsChild.__fields__, ("field1", "field2")) self.assertIs(NamedFieldsChild._get_fields(), NamedFieldsChild.__fields__) - expected = ({'doc': 'field1 doc', 'name': 'field1', 'settable': False}, - {'doc': None, 'name': 'field2'}) + expected = ( + {"doc": "field1 doc", "name": "field1", "settable": False}, + {"doc": None, "name": "field2"}, + ) self.assertTupleEqual(NamedFieldsChild.get_fields_conf(), expected) props = TestAbstractContainerFieldsConf.find_all_properties(NamedFieldsChild) - expected = ['children', 'container_source', 'field1', 'field2', 'fields', 'modified', 'name', 'object_id', - 'parent'] + expected = [ + "children", + "container_source", + "field1", + "field2", + "fields", + "modified", + "name", + "object_id", + "parent", + ] self.assertListEqual(props, expected) def test_inheritance_override(self): class NamedFields(AbstractContainer): - __fields__ = ({'name': 'field1'}, ) + __fields__ = ({"name": "field1"},) class NamedFieldsChild(NamedFields): - __fields__ = ({'name': 'field1', 'doc': 'overridden field', 'settable': False}, ) - - self.assertEqual(NamedFieldsChild._get_fields(), ('field1', )) + __fields__ = ( + { + "name": "field1", + "doc": "overridden field", + "settable": False, + }, + ) + + self.assertEqual(NamedFieldsChild._get_fields(), ("field1",)) ret = NamedFieldsChild.get_fields_conf() - self.assertEqual(ret[0], {'name': 'field1', 'doc': 'overridden field', 'settable': False}) + self.assertEqual( + ret[0], + {"name": "field1", "doc": "overridden field", "settable": False}, + ) # obj = NamedFieldsChild('test name') # with self.assertRaises(AttributeError): @@ -453,53 +497,51 @@ class NamedFieldsChild(NamedFields): def test_mult_inheritance_base_mixin(self): class NamedFields(AbstractContainer): - __fields__ = ({'name': 'field1', 'doc': 'field1 doc', 'settable': False}, ) + __fields__ = ({"name": "field1", "doc": "field1 doc", "settable": False},) class BlankMixin: pass class NamedFieldsChild(NamedFields, BlankMixin): - __fields__ = ({'name': 'field2'}, ) + __fields__ = ({"name": "field2"},) - self.assertTupleEqual(NamedFieldsChild.__fields__, ('field1', 'field2')) + self.assertTupleEqual(NamedFieldsChild.__fields__, ("field1", "field2")) 
self.assertIs(NamedFieldsChild._get_fields(), NamedFieldsChild.__fields__) def test_mult_inheritance_base_container(self): class NamedFields(AbstractContainer): - __fields__ = ({'name': 'field1', 'doc': 'field1 doc', 'settable': False}, ) + __fields__ = ({"name": "field1", "doc": "field1 doc", "settable": False},) class BlankMixin: pass class NamedFieldsChild(BlankMixin, NamedFields): - __fields__ = ({'name': 'field2'}, ) + __fields__ = ({"name": "field2"},) - self.assertTupleEqual(NamedFieldsChild.__fields__, ('field1', 'field2')) + self.assertTupleEqual(NamedFieldsChild.__fields__, ("field1", "field2")) self.assertIs(NamedFieldsChild._get_fields(), NamedFieldsChild.__fields__) class TestContainerFieldsConf(TestCase): - def test_required_name(self): class ContainerRequiredName(Container): - __fields__ = ({'name': 'field1', 'required_name': 'field1 value'}, ) + __fields__ = ({"name": "field1", "required_name": "field1 value"},) - @docval({'name': 'field1', 'doc': 'field1 doc', 'type': None, 'default': None}) + @docval({"name": "field1", "doc": "field1 doc", "type": None, "default": None}) def __init__(self, **kwargs): - super().__init__('test name') - self.field1 = kwargs['field1'] + super().__init__("test name") + self.field1 = kwargs["field1"] - msg = ("Field 'field1' on ContainerRequiredName has a required name and must be a subclass of " - "AbstractContainer.") + msg = "Field 'field1' on ContainerRequiredName has a required name and must be a subclass of AbstractContainer." with self.assertRaisesWith(ValueError, msg): - ContainerRequiredName('field1 value') + ContainerRequiredName("field1 value") - obj1 = Container('test container') + obj1 = Container("test container") msg = "Field 'field1' on ContainerRequiredName must be named 'field1 value'." 
with self.assertRaisesWith(ValueError, msg): ContainerRequiredName(obj1) - obj2 = Container('field1 value') + obj2 = Container("field1 value") obj3 = ContainerRequiredName(obj2) self.assertIs(obj3.field1, obj2) @@ -508,24 +550,24 @@ def __init__(self, **kwargs): def test_child(self): class ContainerWithChild(Container): - __fields__ = ({'name': 'field1', 'child': True}, ) + __fields__ = ({"name": "field1", "child": True},) - @docval({'name': 'field1', 'doc': 'field1 doc', 'type': None, 'default': None}) + @docval({"name": "field1", "doc": "field1 doc", "type": None, "default": None}) def __init__(self, **kwargs): - super().__init__('test name') - self.field1 = kwargs['field1'] + super().__init__("test name") + self.field1 = kwargs["field1"] - child_obj1 = Container('test child 1') + child_obj1 = Container("test child 1") obj1 = ContainerWithChild(child_obj1) self.assertIs(child_obj1.parent, obj1) - child_obj2 = Container('test child 2') + child_obj2 = Container("test child 2") obj3 = ContainerWithChild((child_obj1, child_obj2)) self.assertIs(child_obj1.parent, obj1) # child1 parent is already set self.assertIs(child_obj2.parent, obj3) # child1 parent is already set - child_obj3 = Container('test child 3') - obj4 = ContainerWithChild({'test child 3': child_obj3}) + child_obj3 = Container("test child 3") + obj4 = ContainerWithChild({"test child 3": child_obj3}) self.assertIs(child_obj3.parent, obj4) obj2 = ContainerWithChild() @@ -533,14 +575,14 @@ def __init__(self, **kwargs): def test_setter_set_modified(self): class ContainerWithChild(Container): - __fields__ = ({'name': 'field1', 'child': True}, ) + __fields__ = ({"name": "field1", "child": True},) - @docval({'name': 'field1', 'doc': 'field1 doc', 'type': None, 'default': None}) + @docval({"name": "field1", "doc": "field1 doc", "type": None, "default": None}) def __init__(self, **kwargs): - super().__init__('test name') - self.field1 = kwargs['field1'] + super().__init__("test name") + self.field1 = kwargs["field1"] - child_obj1 = Container('test child 1') + child_obj1 = Container("test child 1") obj1 = ContainerWithChild() obj1.set_modified(False) # set to False so that we can test that it is set to True next obj1.field1 = child_obj1 @@ -555,45 +597,48 @@ def __init__(self, **kwargs): class TestChangeFieldsName(TestCase): - def test_fields(self): class ContainerNewFields(Container): - _fieldsname = '__newfields__' - __newfields__ = ({'name': 'field1', 'doc': 'field1 doc'}, ) + _fieldsname = "__newfields__" + __newfields__ = ({"name": "field1", "doc": "field1 doc"},) - @docval({'name': 'field1', 'doc': 'field1 doc', 'type': None, 'default': None}) + @docval({"name": "field1", "doc": "field1 doc", "type": None, "default": None}) def __init__(self, **kwargs): - super().__init__('test name') - self.field1 = kwargs['field1'] + super().__init__("test name") + self.field1 = kwargs["field1"] - self.assertTupleEqual(ContainerNewFields.__newfields__, ('field1', )) + self.assertTupleEqual(ContainerNewFields.__newfields__, ("field1",)) self.assertIs(ContainerNewFields._get_fields(), ContainerNewFields.__newfields__) - expected = ({'doc': 'field1 doc', 'name': 'field1'}, ) + expected = ({"doc": "field1 doc", "name": "field1"},) self.assertTupleEqual(ContainerNewFields.get_fields_conf(), expected) def test_fields_inheritance(self): class ContainerOldFields(Container): - __fields__ = ({'name': 'field1', 'doc': 'field1 doc'}, ) + __fields__ = ({"name": "field1", "doc": "field1 doc"},) - @docval({'name': 'field1', 'doc': 'field1 doc', 'type': None, 
'default': None}) + @docval({"name": "field1", "doc": "field1 doc", "type": None, "default": None}) def __init__(self, **kwargs): - super().__init__('test name') - self.field1 = kwargs['field1'] + super().__init__("test name") + self.field1 = kwargs["field1"] class ContainerNewFields(ContainerOldFields): - _fieldsname = '__newfields__' - __newfields__ = ({'name': 'field2', 'doc': 'field2 doc'}, ) + _fieldsname = "__newfields__" + __newfields__ = ({"name": "field2", "doc": "field2 doc"},) - @docval({'name': 'field1', 'doc': 'field1 doc', 'type': None, 'default': None}, - {'name': 'field2', 'doc': 'field2 doc', 'type': None, 'default': None}) + @docval( + {"name": "field1", "doc": "field1 doc", "type": None, "default": None}, + {"name": "field2", "doc": "field2 doc", "type": None, "default": None}, + ) def __init__(self, **kwargs): - super().__init__(kwargs['field1']) - self.field2 = kwargs['field2'] + super().__init__(kwargs["field1"]) + self.field2 = kwargs["field2"] - self.assertTupleEqual(ContainerNewFields.__newfields__, ('field1', 'field2')) + self.assertTupleEqual(ContainerNewFields.__newfields__, ("field1", "field2")) self.assertIs(ContainerNewFields._get_fields(), ContainerNewFields.__newfields__) - expected = ({'doc': 'field1 doc', 'name': 'field1'}, - {'doc': 'field2 doc', 'name': 'field2'}, ) + expected = ( + {"doc": "field1 doc", "name": "field1"}, + {"doc": "field2 doc", "name": "field2"}, + ) self.assertTupleEqual(ContainerNewFields.get_fields_conf(), expected) diff --git a/tests/unit/test_io_hdf5.py b/tests/unit/test_io_hdf5.py index 0dae1fbbe..5122312c0 100644 --- a/tests/unit/test_io_hdf5.py +++ b/tests/unit/test_io_hdf5.py @@ -3,12 +3,14 @@ from numbers import Number import numpy as np -from h5py import File, Dataset, Reference +from h5py import Dataset, File, Reference + from hdmf.backends.hdf5 import HDF5IO -from hdmf.build import GroupBuilder, DatasetBuilder, LinkBuilder +from hdmf.build import DatasetBuilder, GroupBuilder, LinkBuilder from hdmf.testing import TestCase from hdmf.utils import get_data_shape -from tests.unit.helpers.utils import Foo, get_foo_buildmanager + +from .helpers.utils import Foo, get_foo_buildmanager class HDF5Encoder(json.JSONEncoder): @@ -33,12 +35,12 @@ def default(self, obj): class GroupBuilderTestCase(TestCase): - ''' + """ A TestCase class for comparing GroupBuilders. 
- ''' + """ def __is_scalar(self, obj): - if hasattr(obj, 'shape'): + if hasattr(obj, "shape"): return len(obj.shape) == 0 else: if any(isinstance(obj, t) for t in (int, str, float, bytes, str)): @@ -101,10 +103,10 @@ def __assert_helper(self, a, b): b_sub = b[k] b_keys.remove(k) if isinstance(a_sub, LinkBuilder) and isinstance(a_sub, LinkBuilder): - a_sub = a_sub['builder'] - b_sub = b_sub['builder'] + a_sub = a_sub["builder"] + b_sub = b_sub["builder"] elif isinstance(a_sub, LinkBuilder) != isinstance(a_sub, LinkBuilder): - reasons.append('%s != %s' % (a_sub, b_sub)) + reasons.append("%s != %s" % (a_sub, b_sub)) if isinstance(a_sub, DatasetBuilder) and isinstance(a_sub, DatasetBuilder): # if not self.__compare_dataset(a_sub, b_sub): # reasons.append('%s != %s' % (a_sub, b_sub)) @@ -120,16 +122,16 @@ def __assert_helper(self, a, b): elif a_array or b_array: # if strings, convert before comparing if b_array: - if b_sub.dtype.char in ('S', 'U'): + if b_sub.dtype.char in ("S", "U"): a_sub = [np.string_(s) for s in a_sub] else: - if a_sub.dtype.char in ('S', 'U'): + if a_sub.dtype.char in ("S", "U"): b_sub = [np.string_(s) for s in b_sub] equal = np.array_equal(a_sub, b_sub) else: equal = a_sub == b_sub if not equal: - reasons.append('%s != %s' % (self.__fmt(a_sub), self.__fmt(b_sub))) + reasons.append("%s != %s" % (self.__fmt(a_sub), self.__fmt(b_sub))) else: reasons.append("'%s' not in both" % k) for k in b_keys: @@ -137,83 +139,93 @@ def __assert_helper(self, a, b): return reasons def assertBuilderEqual(self, a, b): - ''' Tests that two GroupBuilders are equal ''' + """Tests that two GroupBuilders are equal""" reasons = self.__assert_helper(a, b) if len(reasons): - raise AssertionError(', '.join(reasons)) + raise AssertionError(", ".join(reasons)) return True class TestHDF5Writer(GroupBuilderTestCase): - def setUp(self): self.manager = get_foo_buildmanager() self.path = "test_io_hdf5.h5" - self.foo_builder = GroupBuilder('foo1', - attributes={'data_type': 'Foo', - 'namespace': 'test_core', - 'attr1': "bar", - 'object_id': -1}, - datasets={'my_data': DatasetBuilder('my_data', list(range(100, 200, 10)), - attributes={'attr2': 17})}) - self.foo = Foo('foo1', list(range(100, 200, 10)), attr1="bar", attr2=17, attr3=3.14) + self.foo_builder = GroupBuilder( + "foo1", + attributes={ + "data_type": "Foo", + "namespace": "test_core", + "attr1": "bar", + "object_id": -1, + }, + datasets={ + "my_data": DatasetBuilder( + "my_data", + list(range(100, 200, 10)), + attributes={"attr2": 17}, + ) + }, + ) + self.foo = Foo("foo1", list(range(100, 200, 10)), attr1="bar", attr2=17, attr3=3.14) self.manager.prebuilt(self.foo, self.foo_builder) self.builder = GroupBuilder( - 'root', + "root", source=self.path, - groups={'test_bucket': - GroupBuilder('test_bucket', - groups={'foo_holder': - GroupBuilder('foo_holder', - groups={'foo1': self.foo_builder})})}, - attributes={'data_type': 'FooFile'}) + groups={ + "test_bucket": GroupBuilder( + "test_bucket", + groups={"foo_holder": GroupBuilder("foo_holder", groups={"foo1": self.foo_builder})}, + ) + }, + attributes={"data_type": "FooFile"}, + ) def tearDown(self): if os.path.exists(self.path): os.remove(self.path) def check_fields(self): - f = File(self.path, 'r') - self.assertIn('test_bucket', f) - bucket = f.get('test_bucket') - self.assertIn('foo_holder', bucket) - holder = bucket.get('foo_holder') - self.assertIn('foo1', holder) + f = File(self.path, "r") + self.assertIn("test_bucket", f) + bucket = f.get("test_bucket") + self.assertIn("foo_holder", bucket) + 
holder = bucket.get("foo_holder") + self.assertIn("foo1", holder) return f def test_write_builder(self): - writer = HDF5IO(self.path, manager=self.manager, mode='a') + writer = HDF5IO(self.path, manager=self.manager, mode="a") writer.write_builder(self.builder) writer.close() self.check_fields() def test_write_attribute_reference_container(self): - writer = HDF5IO(self.path, manager=self.manager, mode='a') - self.builder.set_attribute('ref_attribute', self.foo) + writer = HDF5IO(self.path, manager=self.manager, mode="a") + self.builder.set_attribute("ref_attribute", self.foo) writer.write_builder(self.builder) writer.close() f = self.check_fields() - self.assertIsInstance(f.attrs['ref_attribute'], Reference) - self.assertEqual(f['test_bucket/foo_holder/foo1'], f[f.attrs['ref_attribute']]) + self.assertIsInstance(f.attrs["ref_attribute"], Reference) + self.assertEqual(f["test_bucket/foo_holder/foo1"], f[f.attrs["ref_attribute"]]) def test_write_attribute_reference_builder(self): - writer = HDF5IO(self.path, manager=self.manager, mode='a') - self.builder.set_attribute('ref_attribute', self.foo_builder) + writer = HDF5IO(self.path, manager=self.manager, mode="a") + self.builder.set_attribute("ref_attribute", self.foo_builder) writer.write_builder(self.builder) writer.close() f = self.check_fields() - self.assertIsInstance(f.attrs['ref_attribute'], Reference) - self.assertEqual(f['test_bucket/foo_holder/foo1'], f[f.attrs['ref_attribute']]) + self.assertIsInstance(f.attrs["ref_attribute"], Reference) + self.assertEqual(f["test_bucket/foo_holder/foo1"], f[f.attrs["ref_attribute"]]) def test_write_context_manager(self): - with HDF5IO(self.path, manager=self.manager, mode='a') as writer: + with HDF5IO(self.path, manager=self.manager, mode="a") as writer: writer.write_builder(self.builder) self.check_fields() def test_read_builder(self): self.maxDiff = None - io = HDF5IO(self.path, manager=self.manager, mode='a') + io = HDF5IO(self.path, manager=self.manager, mode="a") io.write_builder(self.builder) builder = io.read_builder() self.assertBuilderEqual(builder, self.builder) @@ -221,9 +233,9 @@ def test_read_builder(self): def test_dataset_shape(self): self.maxDiff = None - io = HDF5IO(self.path, manager=self.manager, mode='a') + io = HDF5IO(self.path, manager=self.manager, mode="a") io.write_builder(self.builder) builder = io.read_builder() - dset = builder['test_bucket']['foo_holder']['foo1']['my_data'].data + dset = builder["test_bucket"]["foo_holder"]["foo1"]["my_data"].data self.assertEqual(get_data_shape(dset), (10,)) io.close() diff --git a/tests/unit/test_io_hdf5_h5tools.py b/tests/unit/test_io_hdf5_h5tools.py index 5a7798d26..c7eda5965 100644 --- a/tests/unit/test_io_hdf5_h5tools.py +++ b/tests/unit/test_io_hdf5_h5tools.py @@ -1,36 +1,62 @@ """Test module to validate that HDF5IO is working""" import os +import shutil +import tempfile import unittest import warnings from io import BytesIO from pathlib import Path -import shutil -import tempfile import h5py import numpy as np -from h5py import SoftLink, HardLink, ExternalLink, File +from h5py import ExternalLink, File, HardLink, SoftLink from h5py import filters as h5py_filters + +from hdmf.backends.errors import UnsupportedOperation from hdmf.backends.hdf5 import H5DataIO -from hdmf.backends.hdf5.h5tools import HDF5IO, SPEC_LOC_ATTR, H5PY_3 +from hdmf.backends.hdf5.h5tools import H5PY_3, HDF5IO, SPEC_LOC_ATTR from hdmf.backends.io import HDMFIO from hdmf.backends.warnings import BrokenLinkWarning -from hdmf.backends.errors import 
UnsupportedOperation -from hdmf.build import GroupBuilder, DatasetBuilder, BuildManager, TypeMap, OrphanContainerBuildError, LinkBuilder +from hdmf.build import ( + BuildManager, + DatasetBuilder, + GroupBuilder, + LinkBuilder, + OrphanContainerBuildError, + TypeMap, +) from hdmf.container import Container -from hdmf.data_utils import DataChunkIterator, GenericDataChunkIterator, InvalidDataIOError +from hdmf.data_utils import ( + DataChunkIterator, + GenericDataChunkIterator, + InvalidDataIOError, +) from hdmf.spec.catalog import SpecCatalog from hdmf.spec.namespace import NamespaceCatalog, SpecNamespace from hdmf.spec.spec import GroupSpec from hdmf.testing import TestCase -from tests.unit.helpers.utils import (Foo, FooBucket, FooFile, get_foo_buildmanager, - Baz, BazData, BazCpdData, BazBucket, get_baz_buildmanager, - CORE_NAMESPACE, get_temp_filepath, CacheSpecTestHelper, - CustomGroupSpec, CustomDatasetSpec, CustomSpecNamespace) +from .helpers.utils import ( + CORE_NAMESPACE, + Baz, + BazBucket, + BazCpdData, + BazData, + CacheSpecTestHelper, + CustomDatasetSpec, + CustomGroupSpec, + CustomSpecNamespace, + Foo, + FooBucket, + FooFile, + get_baz_buildmanager, + get_foo_buildmanager, + get_temp_filepath, +) try: import zarr + SKIP_ZARR_TESTS = False except ImportError: SKIP_ZARR_TESTS = True @@ -56,7 +82,7 @@ class H5IOTest(TestCase): def setUp(self): self.path = get_temp_filepath() - self.io = HDF5IO(self.path, mode='a') + self.io = HDF5IO(self.path, mode="a") self.f = self.io._file def tearDown(self): @@ -68,44 +94,76 @@ def tearDown(self): ########################################## def test__chunked_iter_fill(self): """Matrix test of HDF5IO.__chunked_iter_fill__ using a DataChunkIterator with different parameters""" - data_opts = {'iterator': range(10), - 'numpy': np.arange(30).reshape(5, 2, 3), - 'list': np.arange(30).reshape(5, 2, 3).tolist(), - 'sparselist1': [1, 2, 3, None, None, None, None, 8, 9, 10], - 'sparselist2': [None, None, 3], - 'sparselist3': [1, 2, 3, None, None], # note: cannot process None in ndarray - 'nanlist': [[[1, 2, 3, np.nan, np.nan, 6], [np.nan, np.nan, 3, 4, np.nan, np.nan]], - [[10, 20, 30, 40, np.nan, np.nan], [np.nan, np.nan, np.nan, np.nan, np.nan, np.nan]]]} - buffer_size_opts = [1, 2, 3, 4] # data is divisible by some of these, some not + data_opts = { + "iterator": range(10), + "numpy": np.arange(30).reshape(5, 2, 3), + "list": np.arange(30).reshape(5, 2, 3).tolist(), + "sparselist1": [1, 2, 3, None, None, None, None, 8, 9, 10], + "sparselist2": [None, None, 3], + "sparselist3": [ + 1, + 2, + 3, + None, + None, + ], # note: cannot process None in ndarray + "nanlist": [ + [ + [1, 2, 3, np.nan, np.nan, 6], + [np.nan, np.nan, 3, 4, np.nan, np.nan], + ], + [ + [10, 20, 30, 40, np.nan, np.nan], + [np.nan, np.nan, np.nan, np.nan, np.nan, np.nan], + ], + ], + } + buffer_size_opts = [ + 1, + 2, + 3, + 4, + ] # data is divisible by some of these, some not for data_type, data in data_opts.items(): iter_axis_opts = [0, 1, 2] - if data_type == 'iterator' or data_type.startswith('sparselist'): + if data_type == "iterator" or data_type.startswith("sparselist"): iter_axis_opts = [0] # only one dimension for iter_axis in iter_axis_opts: for buffer_size in buffer_size_opts: - with self.subTest(data_type=data_type, iter_axis=iter_axis, buffer_size=buffer_size): + with self.subTest( + data_type=data_type, + iter_axis=iter_axis, + buffer_size=buffer_size, + ): with warnings.catch_warnings(record=True): # init may throw UserWarning for iterating over not-first dim of a 
list. ignore here - msg = ("Iterating over an axis other than the first dimension of list or tuple data " - "involves converting the data object to a numpy ndarray, which may incur a " - "computational cost.") + msg = ( + "Iterating over an axis other than the first dimension" + " of list or tuple data involves converting the data" + " object to a numpy ndarray, which may incur a" + " computational cost." + ) warnings.filterwarnings("ignore", message=msg, category=UserWarning) - dci = DataChunkIterator(data=data, buffer_size=buffer_size, iter_axis=iter_axis) + dci = DataChunkIterator( + data=data, + buffer_size=buffer_size, + iter_axis=iter_axis, + ) - dset_name = '%s, %d, %d' % (data_type, iter_axis, buffer_size) + dset_name = "%s, %d, %d" % (data_type, iter_axis, buffer_size) my_dset = HDF5IO.__chunked_iter_fill__(self.f, dset_name, dci) - if data_type == 'iterator': + if data_type == "iterator": self.assertListEqual(my_dset[:].tolist(), list(data)) - elif data_type == 'numpy': + elif data_type == "numpy": self.assertTrue(np.all(my_dset[:] == data)) self.assertTupleEqual(my_dset.shape, data.shape) - elif data_type == 'list' or data_type == 'nanlist': + elif data_type == "list" or data_type == "nanlist": data_np = np.array(data) np.testing.assert_array_equal(my_dset[:], data_np) self.assertTupleEqual(my_dset.shape, data_np.shape) - elif data_type.startswith('sparselist'): + elif data_type.startswith("sparselist"): # replace None in original data with default hdf5 fillvalue 0 data_zeros = np.where(np.equal(np.array(data), None), 0, data) np.testing.assert_array_equal(my_dset[:], data_zeros) @@ -116,20 +174,20 @@ def test__chunked_iter_fill(self): ########################################## def test_write_dataset_scalar(self): a = 10 - self.io.write_dataset(self.f, DatasetBuilder('test_dataset', a, attributes={})) - dset = self.f['test_dataset'] + self.io.write_dataset(self.f, DatasetBuilder("test_dataset", a, attributes={})) + dset = self.f["test_dataset"] self.assertTupleEqual(dset.shape, ()) self.assertEqual(dset[()], a) def test_write_dataset_string(self): - a = 'test string' - self.io.write_dataset(self.f, DatasetBuilder('test_dataset', a, attributes={})) - dset = self.f['test_dataset'] + a = "test string" + self.io.write_dataset(self.f, DatasetBuilder("test_dataset", a, attributes={})) + dset = self.f["test_dataset"] self.assertTupleEqual(dset.shape, ()) # self.assertEqual(dset[()].decode('utf-8'), a) read_a = dset[()] if isinstance(read_a, bytes): - read_a = read_a.decode('utf-8') + read_a = read_a.decode("utf-8") self.assertEqual(read_a, a) ########################################## @@ -137,108 +195,125 @@ def test_write_dataset_string(self): ########################################## def test_write_dataset_list(self): a = np.arange(30).reshape(5, 2, 3) - self.io.write_dataset(self.f, DatasetBuilder('test_dataset', a.tolist(), attributes={})) - dset = self.f['test_dataset'] + self.io.write_dataset(self.f, DatasetBuilder("test_dataset", a.tolist(), attributes={})) + dset = self.f["test_dataset"] self.assertTrue(np.all(dset[:] == a)) def test_write_dataset_list_compress_gzip(self): - a = H5DataIO(np.arange(30).reshape(5, 2, 3), - compression='gzip', - compression_opts=5, - shuffle=True, - fletcher32=True) - self.io.write_dataset(self.f, DatasetBuilder('test_dataset', a, attributes={})) - dset = self.f['test_dataset'] + a = H5DataIO( + np.arange(30).reshape(5, 2, 3), + compression="gzip", + compression_opts=5, + shuffle=True, + fletcher32=True, + ) + self.io.write_dataset(self.f, 
DatasetBuilder("test_dataset", a, attributes={})) + dset = self.f["test_dataset"] self.assertTrue(np.all(dset[:] == a.data)) - self.assertEqual(dset.compression, 'gzip') + self.assertEqual(dset.compression, "gzip") self.assertEqual(dset.compression_opts, 5) self.assertEqual(dset.shuffle, True) self.assertEqual(dset.fletcher32, True) - @unittest.skipIf("lzf" not in h5py_filters.encode, - "LZF compression not supported in this h5py library install") + @unittest.skipIf( + "lzf" not in h5py_filters.encode, + "LZF compression not supported in this h5py library install", + ) def test_write_dataset_list_compress_lzf(self): - warn_msg = ("lzf compression may not be available on all installations of HDF5. Use of gzip is " - "recommended to ensure portability of the generated HDF5 files.") + warn_msg = ( + "lzf compression may not be available on all installations of HDF5. Use of" + " gzip is recommended to ensure portability of the generated HDF5 files." + ) with self.assertWarnsWith(UserWarning, warn_msg): - a = H5DataIO(np.arange(30).reshape(5, 2, 3), - compression='lzf', - shuffle=True, - fletcher32=True) - self.io.write_dataset(self.f, DatasetBuilder('test_dataset', a, attributes={})) - dset = self.f['test_dataset'] + a = H5DataIO( + np.arange(30).reshape(5, 2, 3), + compression="lzf", + shuffle=True, + fletcher32=True, + ) + self.io.write_dataset(self.f, DatasetBuilder("test_dataset", a, attributes={})) + dset = self.f["test_dataset"] self.assertTrue(np.all(dset[:] == a.data)) - self.assertEqual(dset.compression, 'lzf') + self.assertEqual(dset.compression, "lzf") self.assertEqual(dset.shuffle, True) self.assertEqual(dset.fletcher32, True) - @unittest.skipIf("szip" not in h5py_filters.encode, - "SZIP compression not supported in this h5py library install") + @unittest.skipIf( + "szip" not in h5py_filters.encode, + "SZIP compression not supported in this h5py library install", + ) def test_write_dataset_list_compress_szip(self): - warn_msg = ("szip compression may not be available on all installations of HDF5. Use of gzip is " - "recommended to ensure portability of the generated HDF5 files.") + warn_msg = ( + "szip compression may not be available on all installations of HDF5. Use of" + " gzip is recommended to ensure portability of the generated HDF5 files." 
+ ) with self.assertWarnsWith(UserWarning, warn_msg): - a = H5DataIO(np.arange(30).reshape(5, 2, 3), - compression='szip', - compression_opts=('ec', 16), - shuffle=True, - fletcher32=True) - self.io.write_dataset(self.f, DatasetBuilder('test_dataset', a, attributes={})) - dset = self.f['test_dataset'] + a = H5DataIO( + np.arange(30).reshape(5, 2, 3), + compression="szip", + compression_opts=("ec", 16), + shuffle=True, + fletcher32=True, + ) + self.io.write_dataset(self.f, DatasetBuilder("test_dataset", a, attributes={})) + dset = self.f["test_dataset"] self.assertTrue(np.all(dset[:] == a.data)) - self.assertEqual(dset.compression, 'szip') + self.assertEqual(dset.compression, "szip") self.assertEqual(dset.shuffle, True) self.assertEqual(dset.fletcher32, True) def test_write_dataset_list_compress_available_int_filters(self): - a = H5DataIO(np.arange(30).reshape(5, 2, 3), - compression=1, - shuffle=True, - fletcher32=True, - allow_plugin_filters=True) - self.io.write_dataset(self.f, DatasetBuilder('test_dataset', a, attributes={})) - dset = self.f['test_dataset'] + a = H5DataIO( + np.arange(30).reshape(5, 2, 3), + compression=1, + shuffle=True, + fletcher32=True, + allow_plugin_filters=True, + ) + self.io.write_dataset(self.f, DatasetBuilder("test_dataset", a, attributes={})) + dset = self.f["test_dataset"] self.assertTrue(np.all(dset[:] == a.data)) - self.assertEqual(dset.compression, 'gzip') + self.assertEqual(dset.compression, "gzip") self.assertEqual(dset.shuffle, True) self.assertEqual(dset.fletcher32, True) def test_write_dataset_list_enable_default_compress(self): - a = H5DataIO(np.arange(30).reshape(5, 2, 3), - compression=True) - self.assertEqual(a.io_settings['compression'], 'gzip') - self.io.write_dataset(self.f, DatasetBuilder('test_dataset', a, attributes={})) - dset = self.f['test_dataset'] + a = H5DataIO(np.arange(30).reshape(5, 2, 3), compression=True) + self.assertEqual(a.io_settings["compression"], "gzip") + self.io.write_dataset(self.f, DatasetBuilder("test_dataset", a, attributes={})) + dset = self.f["test_dataset"] self.assertTrue(np.all(dset[:] == a.data)) - self.assertEqual(dset.compression, 'gzip') + self.assertEqual(dset.compression, "gzip") def test_write_dataset_list_disable_default_compress(self): - msg = ("Compression disabled by compression=False setting. compression_opts parameter will, therefore, " - "be ignored.") + msg = ( + "Compression disabled by compression=False setting. compression_opts parameter will, therefore, be ignored." 
+ ) with self.assertWarnsWith(UserWarning, msg): - a = H5DataIO(np.arange(30).reshape(5, 2, 3), - compression=False, - compression_opts=5) - self.assertFalse('compression_ops' in a.io_settings) - self.assertFalse('compression' in a.io_settings) - - self.io.write_dataset(self.f, DatasetBuilder('test_dataset', a, attributes={})) - dset = self.f['test_dataset'] + a = H5DataIO( + np.arange(30).reshape(5, 2, 3), + compression=False, + compression_opts=5, + ) + self.assertFalse("compression_ops" in a.io_settings) + self.assertFalse("compression" in a.io_settings) + + self.io.write_dataset(self.f, DatasetBuilder("test_dataset", a, attributes={})) + dset = self.f["test_dataset"] self.assertTrue(np.all(dset[:] == a.data)) self.assertEqual(dset.compression, None) def test_write_dataset_list_chunked(self): - a = H5DataIO(np.arange(30).reshape(5, 2, 3), - chunks=(1, 1, 3)) - self.io.write_dataset(self.f, DatasetBuilder('test_dataset', a, attributes={})) - dset = self.f['test_dataset'] + a = H5DataIO(np.arange(30).reshape(5, 2, 3), chunks=(1, 1, 3)) + self.io.write_dataset(self.f, DatasetBuilder("test_dataset", a, attributes={})) + dset = self.f["test_dataset"] self.assertTrue(np.all(dset[:] == a.data)) self.assertEqual(dset.chunks, (1, 1, 3)) def test_write_dataset_list_fillvalue(self): a = H5DataIO(np.arange(20).reshape(5, 4), fillvalue=-1) - self.io.write_dataset(self.f, DatasetBuilder('test_dataset', a, attributes={})) - dset = self.f['test_dataset'] + self.io.write_dataset(self.f, DatasetBuilder("test_dataset", a, attributes={})) + dset = self.f["test_dataset"] self.assertTrue(np.all(dset[:] == a.data)) self.assertEqual(dset.fillvalue, -1) @@ -246,47 +321,59 @@ def test_write_dataset_list_fillvalue(self): # write_dataset tests: tables ########################################## def test_write_table(self): - cmpd_dt = np.dtype([('a', np.int32), ('b', np.float64)]) + cmpd_dt = np.dtype([("a", np.int32), ("b", np.float64)]) data = np.zeros(10, dtype=cmpd_dt) - data['a'][1] = 101 - data['b'][1] = 0.1 - dt = [{'name': 'a', 'dtype': 'int32', 'doc': 'a column'}, - {'name': 'b', 'dtype': 'float64', 'doc': 'b column'}] - self.io.write_dataset(self.f, DatasetBuilder('test_dataset', data, attributes={}, dtype=dt)) - dset = self.f['test_dataset'] - self.assertEqual(dset['a'].tolist(), data['a'].tolist()) - self.assertEqual(dset['b'].tolist(), data['b'].tolist()) + data["a"][1] = 101 + data["b"][1] = 0.1 + dt = [ + {"name": "a", "dtype": "int32", "doc": "a column"}, + {"name": "b", "dtype": "float64", "doc": "b column"}, + ] + self.io.write_dataset( + self.f, + DatasetBuilder("test_dataset", data, attributes={}, dtype=dt), + ) + dset = self.f["test_dataset"] + self.assertEqual(dset["a"].tolist(), data["a"].tolist()) + self.assertEqual(dset["b"].tolist(), data["b"].tolist()) def test_write_table_nested(self): - b_cmpd_dt = np.dtype([('c', np.int32), ('d', np.float64)]) - cmpd_dt = np.dtype([('a', np.int32), ('b', b_cmpd_dt)]) + b_cmpd_dt = np.dtype([("c", np.int32), ("d", np.float64)]) + cmpd_dt = np.dtype([("a", np.int32), ("b", b_cmpd_dt)]) data = np.zeros(10, dtype=cmpd_dt) - data['a'][1] = 101 - data['b']['c'] = 202 - data['b']['d'] = 10.1 - b_dt = [{'name': 'c', 'dtype': 'int32', 'doc': 'c column'}, - {'name': 'd', 'dtype': 'float64', 'doc': 'd column'}] - dt = [{'name': 'a', 'dtype': 'int32', 'doc': 'a column'}, - {'name': 'b', 'dtype': b_dt, 'doc': 'b column'}] - self.io.write_dataset(self.f, DatasetBuilder('test_dataset', data, attributes={}, dtype=dt)) - dset = self.f['test_dataset'] - 
self.assertEqual(dset['a'].tolist(), data['a'].tolist()) - self.assertEqual(dset['b'].tolist(), data['b'].tolist()) + data["a"][1] = 101 + data["b"]["c"] = 202 + data["b"]["d"] = 10.1 + b_dt = [ + {"name": "c", "dtype": "int32", "doc": "c column"}, + {"name": "d", "dtype": "float64", "doc": "d column"}, + ] + dt = [ + {"name": "a", "dtype": "int32", "doc": "a column"}, + {"name": "b", "dtype": b_dt, "doc": "b column"}, + ] + self.io.write_dataset( + self.f, + DatasetBuilder("test_dataset", data, attributes={}, dtype=dt), + ) + dset = self.f["test_dataset"] + self.assertEqual(dset["a"].tolist(), data["a"].tolist()) + self.assertEqual(dset["b"].tolist(), data["b"].tolist()) ########################################## # write_dataset tests: Iterable ########################################## def test_write_dataset_iterable(self): - self.io.write_dataset(self.f, DatasetBuilder('test_dataset', range(10), attributes={})) - dset = self.f['test_dataset'] + self.io.write_dataset(self.f, DatasetBuilder("test_dataset", range(10), attributes={})) + dset = self.f["test_dataset"] self.assertListEqual(dset[:].tolist(), list(range(10))) def test_write_dataset_iterable_multidimensional_array(self): a = np.arange(30).reshape(5, 2, 3) aiter = iter(a) daiter = DataChunkIterator.from_iterable(aiter, buffer_size=2) - self.io.write_dataset(self.f, DatasetBuilder('test_dataset', daiter, attributes={})) - dset = self.f['test_dataset'] + self.io.write_dataset(self.f, DatasetBuilder("test_dataset", daiter, attributes={})) + dset = self.f["test_dataset"] self.assertListEqual(dset[:].tolist(), a.tolist()) def test_write_multi_dci_oaat(self): @@ -300,14 +387,14 @@ def test_write_multi_dci_oaat(self): daiter1 = DataChunkIterator.from_iterable(aiter, buffer_size=2) daiter2 = DataChunkIterator.from_iterable(biter, buffer_size=2) builder = GroupBuilder("root") - dataset1 = DatasetBuilder('test_dataset1', daiter1) - dataset2 = DatasetBuilder('test_dataset2', daiter2) + dataset1 = DatasetBuilder("test_dataset1", daiter1) + dataset2 = DatasetBuilder("test_dataset2", daiter2) builder.set_dataset(dataset1) builder.set_dataset(dataset2) self.io.write_builder(builder) - dset1 = self.f['test_dataset1'] + dset1 = self.f["test_dataset1"] self.assertListEqual(dset1[:].tolist(), a.tolist()) - dset2 = self.f['test_dataset2'] + dset2 = self.f["test_dataset2"] self.assertListEqual(dset2[:].tolist(), b.tolist()) def test_write_multi_dci_conc(self): @@ -321,30 +408,35 @@ def test_write_multi_dci_conc(self): daiter1 = DataChunkIterator.from_iterable(aiter, buffer_size=2) daiter2 = DataChunkIterator.from_iterable(biter, buffer_size=2) builder = GroupBuilder("root") - dataset1 = DatasetBuilder('test_dataset1', daiter1) - dataset2 = DatasetBuilder('test_dataset2', daiter2) + dataset1 = DatasetBuilder("test_dataset1", daiter1) + dataset2 = DatasetBuilder("test_dataset2", daiter2) builder.set_dataset(dataset1) builder.set_dataset(dataset2) self.io.write_builder(builder) - dset1 = self.f['test_dataset1'] + dset1 = self.f["test_dataset1"] self.assertListEqual(dset1[:].tolist(), a.tolist()) - dset2 = self.f['test_dataset2'] + dset2 = self.f["test_dataset2"] self.assertListEqual(dset2[:].tolist(), b.tolist()) def test_write_dataset_iterable_multidimensional_array_compression(self): a = np.arange(30).reshape(5, 2, 3) aiter = iter(a) daiter = DataChunkIterator.from_iterable(aiter, buffer_size=2) - wrapped_daiter = H5DataIO(data=daiter, - compression='gzip', - compression_opts=5, - shuffle=True, - fletcher32=True) - self.io.write_dataset(self.f, 
DatasetBuilder('test_dataset', wrapped_daiter, attributes={})) - dset = self.f['test_dataset'] + wrapped_daiter = H5DataIO( + data=daiter, + compression="gzip", + compression_opts=5, + shuffle=True, + fletcher32=True, + ) + self.io.write_dataset( + self.f, + DatasetBuilder("test_dataset", wrapped_daiter, attributes={}), + ) + dset = self.f["test_dataset"] self.assertEqual(dset.shape, a.shape) self.assertListEqual(dset[:].tolist(), a.tolist()) - self.assertEqual(dset.compression, 'gzip') + self.assertEqual(dset.compression, "gzip") self.assertEqual(dset.compression_opts, 5) self.assertEqual(dset.shuffle, True) self.assertEqual(dset.fletcher32, True) @@ -354,44 +446,50 @@ def test_write_dataset_iterable_multidimensional_array_compression(self): ############################################# def test_write_dataset_data_chunk_iterator(self): dci = DataChunkIterator(data=np.arange(10), buffer_size=2) - self.io.write_dataset(self.f, DatasetBuilder('test_dataset', dci, attributes={}, dtype=dci.dtype)) - dset = self.f['test_dataset'] + self.io.write_dataset( + self.f, + DatasetBuilder("test_dataset", dci, attributes={}, dtype=dci.dtype), + ) + dset = self.f["test_dataset"] self.assertListEqual(dset[:].tolist(), list(range(10))) self.assertEqual(dset[:].dtype, dci.dtype) def test_write_dataset_data_chunk_iterator_with_compression(self): dci = DataChunkIterator(data=np.arange(10), buffer_size=2) - wrapped_dci = H5DataIO(data=dci, - compression='gzip', - compression_opts=5, - shuffle=True, - fletcher32=True, - chunks=(2,)) - self.io.write_dataset(self.f, DatasetBuilder('test_dataset', wrapped_dci, attributes={})) - dset = self.f['test_dataset'] + wrapped_dci = H5DataIO( + data=dci, + compression="gzip", + compression_opts=5, + shuffle=True, + fletcher32=True, + chunks=(2,), + ) + self.io.write_dataset(self.f, DatasetBuilder("test_dataset", wrapped_dci, attributes={})) + dset = self.f["test_dataset"] self.assertListEqual(dset[:].tolist(), list(range(10))) - self.assertEqual(dset.compression, 'gzip') + self.assertEqual(dset.compression, "gzip") self.assertEqual(dset.compression_opts, 5) self.assertEqual(dset.shuffle, True) self.assertEqual(dset.fletcher32, True) self.assertEqual(dset.chunks, (2,)) def test_pass_through_of_recommended_chunks(self): - class DC(DataChunkIterator): def recommended_chunk_shape(self): return (5, 1, 1) dci = DC(data=np.arange(30).reshape(5, 2, 3)) - wrapped_dci = H5DataIO(data=dci, - compression='gzip', - compression_opts=5, - shuffle=True, - fletcher32=True) - self.io.write_dataset(self.f, DatasetBuilder('test_dataset', wrapped_dci, attributes={})) - dset = self.f['test_dataset'] + wrapped_dci = H5DataIO( + data=dci, + compression="gzip", + compression_opts=5, + shuffle=True, + fletcher32=True, + ) + self.io.write_dataset(self.f, DatasetBuilder("test_dataset", wrapped_dci, attributes={})) + dset = self.f["test_dataset"] self.assertEqual(dset.chunks, (5, 1, 1)) - self.assertEqual(dset.compression, 'gzip') + self.assertEqual(dset.compression, "gzip") self.assertEqual(dset.compression_opts, 5) self.assertEqual(dset.shuffle, True) self.assertEqual(dset.fletcher32, True) @@ -399,8 +497,8 @@ def recommended_chunk_shape(self): def test_dci_h5dataset(self): data = np.arange(30).reshape(5, 2, 3) dci1 = DataChunkIterator(data=data, buffer_size=1, iter_axis=0) - HDF5IO.__chunked_iter_fill__(self.f, 'test_dataset', dci1) - dset = self.f['test_dataset'] + HDF5IO.__chunked_iter_fill__(self.f, "test_dataset", dci1) + dset = self.f["test_dataset"] dci2 = DataChunkIterator(data=dset, 
buffer_size=2, iter_axis=2) chunk = dci2.next() @@ -416,8 +514,8 @@ def test_dci_h5dataset(self): def test_dci_h5dataset_sparse_matched(self): data = [1, 2, 3, None, None, None, None, 8, 9, 10] dci1 = DataChunkIterator(data=data, buffer_size=3) - HDF5IO.__chunked_iter_fill__(self.f, 'test_dataset', dci1) - dset = self.f['test_dataset'] + HDF5IO.__chunked_iter_fill__(self.f, "test_dataset", dci1) + dset = self.f["test_dataset"] dci2 = DataChunkIterator(data=dset, buffer_size=2) # dataset is read such that Nones in original data were not written, but are read as 0s @@ -450,8 +548,8 @@ def test_dci_h5dataset_sparse_matched(self): def test_dci_h5dataset_sparse_unmatched(self): data = [1, 2, 3, None, None, None, None, 8, 9, 10] dci1 = DataChunkIterator(data=data, buffer_size=3) - HDF5IO.__chunked_iter_fill__(self.f, 'test_dataset', dci1) - dset = self.f['test_dataset'] + HDF5IO.__chunked_iter_fill__(self.f, "test_dataset", dci1) + dset = self.f["test_dataset"] dci2 = DataChunkIterator(data=dset, buffer_size=4) # dataset is read such that Nones in original data were not written, but are read as 0s @@ -478,8 +576,8 @@ def test_dci_h5dataset_sparse_unmatched(self): def test_dci_h5dataset_scalar(self): data = [1] dci1 = DataChunkIterator(data=data, buffer_size=3) - HDF5IO.__chunked_iter_fill__(self.f, 'test_dataset', dci1) - dset = self.f['test_dataset'] + HDF5IO.__chunked_iter_fill__(self.f, "test_dataset", dci1) + dset = self.f["test_dataset"] dci2 = DataChunkIterator(data=dset, buffer_size=4) # dataset is read such that Nones in original data were not written, but are read as 0s @@ -503,7 +601,10 @@ def test_dci_h5dataset_scalar(self): def test_write_dataset_generic_data_chunk_iterator(self): array = np.arange(10) dci = NumpyArrayGenericDataChunkIterator(array=array) - self.io.write_dataset(self.f, DatasetBuilder("test_dataset", dci, attributes={}, dtype=dci.dtype)) + self.io.write_dataset( + self.f, + DatasetBuilder("test_dataset", dci, attributes={}, dtype=dci.dtype), + ) dset = self.f["test_dataset"] self.assertListEqual(dset[:].tolist(), list(array)) self.assertEqual(dset[:].dtype, dci.dtype) @@ -552,67 +653,93 @@ def test_pass_through_of_chunk_shape_generic_data_chunk_iterator(self): def test_warning_on_non_gzip_compression(self): # Make sure no warning is issued when using gzip with warnings.catch_warnings(record=True) as w: - dset = H5DataIO(np.arange(30), - compression='gzip') + dset = H5DataIO(np.arange(30), compression="gzip") self.assertEqual(len(w), 0) - self.assertEqual(dset.io_settings['compression'], 'gzip') + self.assertEqual(dset.io_settings["compression"], "gzip") # Make sure a warning is issued when using szip (even if installed) - warn_msg = ("szip compression may not be available on all installations of HDF5. Use of gzip is " - "recommended to ensure portability of the generated HDF5 files.") + warn_msg = ( + "szip compression may not be available on all installations of HDF5. Use of" + " gzip is recommended to ensure portability of the generated HDF5 files." 
+ ) if "szip" in h5py_filters.encode: with self.assertWarnsWith(UserWarning, warn_msg): - dset = H5DataIO(np.arange(30), - compression='szip', - compression_opts=('ec', 16)) - self.assertEqual(dset.io_settings['compression'], 'szip') + dset = H5DataIO( + np.arange(30), + compression="szip", + compression_opts=("ec", 16), + ) + self.assertEqual(dset.io_settings["compression"], "szip") else: with self.assertRaises(ValueError): with self.assertWarnsWith(UserWarning, warn_msg): - dset = H5DataIO(np.arange(30), - compression='szip', - compression_opts=('ec', 16)) - self.assertEqual(dset.io_settings['compression'], 'szip') + dset = H5DataIO( + np.arange(30), + compression="szip", + compression_opts=("ec", 16), + ) + self.assertEqual(dset.io_settings["compression"], "szip") # Make sure a warning is issued when using lzf compression - warn_msg = ("lzf compression may not be available on all installations of HDF5. Use of gzip is " - "recommended to ensure portability of the generated HDF5 files.") + warn_msg = ( + "lzf compression may not be available on all installations of HDF5. Use of" + " gzip is recommended to ensure portability of the generated HDF5 files." + ) with self.assertWarnsWith(UserWarning, warn_msg): - dset = H5DataIO(np.arange(30), - compression='lzf') - self.assertEqual(dset.io_settings['compression'], 'lzf') + dset = H5DataIO(np.arange(30), compression="lzf") + self.assertEqual(dset.io_settings["compression"], "lzf") def test_error_on_unsupported_compression_filter(self): # Make sure gzip does not raise an error try: - H5DataIO(np.arange(30), compression='gzip', compression_opts=5) + H5DataIO(np.arange(30), compression="gzip", compression_opts=5) except ValueError: self.fail("Using gzip compression raised a ValueError when it should not") # Make sure szip raises an error if not installed (or does not raise an error if installed) - warn_msg = ("szip compression may not be available on all installations of HDF5. Use of gzip is " - "recommended to ensure portability of the generated HDF5 files.") + warn_msg = ( + "szip compression may not be available on all installations of HDF5. Use of" + " gzip is recommended to ensure portability of the generated HDF5 files." + ) if "szip" not in h5py_filters.encode: with self.assertRaises(ValueError): with self.assertWarnsWith(UserWarning, warn_msg): - H5DataIO(np.arange(30), compression='szip', compression_opts=('ec', 16)) + H5DataIO( + np.arange(30), + compression="szip", + compression_opts=("ec", 16), + ) else: try: with self.assertWarnsWith(UserWarning, warn_msg): - H5DataIO(np.arange(30), compression='szip', compression_opts=('ec', 16)) + H5DataIO( + np.arange(30), + compression="szip", + compression_opts=("ec", 16), + ) except ValueError: self.fail("SZIP is installed but H5DataIO still raises an error") # Test error on illegal (i.e., a made-up compressor) with self.assertRaises(ValueError): - warn_msg = ("unknown compression may not be available on all installations of HDF5. Use of gzip is " - "recommended to ensure portability of the generated HDF5 files.") + warn_msg = ( + "unknown compression may not be available on all installations of HDF5." + " Use of gzip is recommended to ensure portability of the generated" + " HDF5 files." 
+ ) with self.assertWarnsWith(UserWarning, warn_msg): H5DataIO(np.arange(30), compression="unknown") # Make sure passing int compression filter raise an error if not installed if not h5py_filters.h5z.filter_avail(h5py_filters.h5z.FILTER_MAX): with self.assertRaises(ValueError): - warn_msg = ("%i compression may not be available on all installations of HDF5. Use of gzip is " - "recommended to ensure portability of the generated HDF5 files." - % h5py_filters.h5z.FILTER_MAX) + warn_msg = ( + "%i compression may not be available on all installations of HDF5." + " Use of gzip is recommended to ensure portability of the generated" + " HDF5 files." + % h5py_filters.h5z.FILTER_MAX + ) with self.assertWarnsWith(UserWarning, warn_msg): - H5DataIO(np.arange(30), compression=h5py_filters.h5z.FILTER_MAX, allow_plugin_filters=True) + H5DataIO( + np.arange(30), + compression=h5py_filters.h5z.FILTER_MAX, + allow_plugin_filters=True, + ) # Make sure available int compression filters raise an error without passing allow_plugin_filters=True with self.assertRaises(ValueError): H5DataIO(np.arange(30), compression=h5py_filters.h5z.FILTER_DEFLATE) @@ -620,61 +747,66 @@ def test_error_on_unsupported_compression_filter(self): def test_value_error_on_incompatible_compression_opts(self): # Make sure we warn when gzip with szip compression options is used with self.assertRaises(ValueError): - H5DataIO(np.arange(30), compression='gzip', compression_opts=('ec', 16)) + H5DataIO(np.arange(30), compression="gzip", compression_opts=("ec", 16)) # Make sure we warn if gzip with a too high aggression is used with self.assertRaises(ValueError): - H5DataIO(np.arange(30), compression='gzip', compression_opts=100) + H5DataIO(np.arange(30), compression="gzip", compression_opts=100) # Make sure we warn if lzf with gzip compression option is used with self.assertRaises(ValueError): - H5DataIO(np.arange(30), compression='lzf', compression_opts=5) + H5DataIO(np.arange(30), compression="lzf", compression_opts=5) # Make sure we warn if lzf with szip compression option is used with self.assertRaises(ValueError): - H5DataIO(np.arange(30), compression='lzf', compression_opts=('ec', 16)) + H5DataIO(np.arange(30), compression="lzf", compression_opts=("ec", 16)) # Make sure we warn if szip with gzip compression option is used with self.assertRaises(ValueError): - H5DataIO(np.arange(30), compression='szip', compression_opts=4) + H5DataIO(np.arange(30), compression="szip", compression_opts=4) # Make sure szip raises a ValueError if bad options are used (odd compression option) with self.assertRaises(ValueError): - H5DataIO(np.arange(30), compression='szip', compression_opts=('ec', 3)) + H5DataIO(np.arange(30), compression="szip", compression_opts=("ec", 3)) # Make sure szip raises a ValueError if bad options are used (bad methods) with self.assertRaises(ValueError): - H5DataIO(np.arange(30), compression='szip', compression_opts=('bad_method', 16)) + H5DataIO( + np.arange(30), + compression="szip", + compression_opts=("bad_method", 16), + ) def test_warning_on_linking_of_regular_array(self): msg = "link_data parameter in H5DataIO will be ignored" with self.assertWarnsWith(UserWarning, msg): - dset = H5DataIO(np.arange(30), - link_data=True) + dset = H5DataIO(np.arange(30), link_data=True) self.assertEqual(dset.link_data, False) def test_warning_on_setting_io_options_on_h5dataset_input(self): - self.io.write_dataset(self.f, DatasetBuilder('test_dataset', np.arange(10), attributes={})) + self.io.write_dataset(self.f, DatasetBuilder("test_dataset", 
np.arange(10), attributes={})) msg = "maxshape in H5DataIO will be ignored with H5DataIO.data being an HDF5 dataset" with self.assertWarnsWith(UserWarning, msg): - H5DataIO(self.f['test_dataset'], - compression='gzip', - compression_opts=4, - fletcher32=True, - shuffle=True, - maxshape=(10, 20), - chunks=(10,), - fillvalue=100) + H5DataIO( + self.f["test_dataset"], + compression="gzip", + compression_opts=4, + fletcher32=True, + shuffle=True, + maxshape=(10, 20), + chunks=(10,), + fillvalue=100, + ) def test_h5dataio_array_conversion_numpy(self): # Test that H5DataIO.__array__ is working when wrapping an ndarray - test_speed = np.array([10., 20.]) + test_speed = np.array([10.0, 20.0]) data = H5DataIO((test_speed)) self.assertTrue(np.all(np.isfinite(data))) # Force call of H5DataIO.__array__ def test_h5dataio_array_conversion_list(self): # Test that H5DataIO.__array__ is working when wrapping a python list - test_speed = [10., 20.] + test_speed = [10.0, 20.0] data = H5DataIO(test_speed) self.assertTrue(np.all(np.isfinite(data))) # Force call of H5DataIO.__array__ def test_h5dataio_array_conversion_datachunkiterator(self): # Test that H5DataIO.__array__ is working when wrapping a python list - test_speed = DataChunkIterator(data=[10., 20.]) + test_speed = DataChunkIterator(data=[10.0, 20.0]) data = H5DataIO(test_speed) with self.assertRaises(NotImplementedError): np.isfinite(data) # Force call of H5DataIO.__array__ @@ -683,65 +815,86 @@ def test_h5dataio_array_conversion_datachunkiterator(self): # Copy/Link h5py.Dataset object ############################################# def test_link_h5py_dataset_input(self): - self.io.write_dataset(self.f, DatasetBuilder('test_dataset', np.arange(10), attributes={})) - self.io.write_dataset(self.f, DatasetBuilder('test_softlink', self.f['test_dataset'], attributes={})) - self.assertTrue(isinstance(self.f.get('test_softlink', getlink=True), SoftLink)) + self.io.write_dataset(self.f, DatasetBuilder("test_dataset", np.arange(10), attributes={})) + self.io.write_dataset( + self.f, + DatasetBuilder("test_softlink", self.f["test_dataset"], attributes={}), + ) + self.assertTrue(isinstance(self.f.get("test_softlink", getlink=True), SoftLink)) def test_copy_h5py_dataset_input(self): - self.io.write_dataset(self.f, DatasetBuilder('test_dataset', np.arange(10), attributes={})) - self.io.write_dataset(self.f, - DatasetBuilder('test_copy', self.f['test_dataset'], attributes={}), - link_data=False) - self.assertTrue(isinstance(self.f.get('test_copy', getlink=True), HardLink)) - self.assertListEqual(self.f['test_dataset'][:].tolist(), - self.f['test_copy'][:].tolist()) + self.io.write_dataset(self.f, DatasetBuilder("test_dataset", np.arange(10), attributes={})) + self.io.write_dataset( + self.f, + DatasetBuilder("test_copy", self.f["test_dataset"], attributes={}), + link_data=False, + ) + self.assertTrue(isinstance(self.f.get("test_copy", getlink=True), HardLink)) + self.assertListEqual(self.f["test_dataset"][:].tolist(), self.f["test_copy"][:].tolist()) def test_link_h5py_dataset_h5dataio_input(self): - self.io.write_dataset(self.f, DatasetBuilder('test_dataset', np.arange(10), attributes={})) - self.io.write_dataset(self.f, DatasetBuilder('test_softlink', - H5DataIO(data=self.f['test_dataset'], - link_data=True), - attributes={})) - self.assertTrue(isinstance(self.f.get('test_softlink', getlink=True), SoftLink)) + self.io.write_dataset(self.f, DatasetBuilder("test_dataset", np.arange(10), attributes={})) + self.io.write_dataset( + self.f, + DatasetBuilder( + 
"test_softlink", + H5DataIO(data=self.f["test_dataset"], link_data=True), + attributes={}, + ), + ) + self.assertTrue(isinstance(self.f.get("test_softlink", getlink=True), SoftLink)) def test_copy_h5py_dataset_h5dataio_input(self): - self.io.write_dataset(self.f, DatasetBuilder('test_dataset', np.arange(10), attributes={})) - self.io.write_dataset(self.f, - DatasetBuilder('test_copy', - H5DataIO(data=self.f['test_dataset'], - link_data=False), # Force dataset copy - attributes={})) # Make sure the default behavior is set to link the data - self.assertTrue(isinstance(self.f.get('test_copy', getlink=True), HardLink)) - self.assertListEqual(self.f['test_dataset'][:].tolist(), - self.f['test_copy'][:].tolist()) + self.io.write_dataset(self.f, DatasetBuilder("test_dataset", np.arange(10), attributes={})) + self.io.write_dataset( + self.f, + DatasetBuilder( + "test_copy", + H5DataIO(data=self.f["test_dataset"], link_data=False), # Force dataset copy + attributes={}, + ), + ) # Make sure the default behavior is set to link the data + self.assertTrue(isinstance(self.f.get("test_copy", getlink=True), HardLink)) + self.assertListEqual(self.f["test_dataset"][:].tolist(), self.f["test_copy"][:].tolist()) def test_list_fill_empty(self): - dset = self.io.__list_fill__(self.f, 'empty_dataset', [], options={'dtype': int, 'io_settings': {}}) + dset = self.io.__list_fill__( + self.f, + "empty_dataset", + [], + options={"dtype": int, "io_settings": {}}, + ) self.assertTupleEqual(dset.shape, (0,)) def test_list_fill_empty_no_dtype(self): with self.assertRaisesRegex(Exception, r"cannot add \S+ to [/\S]+ - could not determine type"): - self.io.__list_fill__(self.f, 'empty_dataset', []) + self.io.__list_fill__(self.f, "empty_dataset", []) def test_read_str(self): - a = ['a', 'bb', 'ccc', 'dddd', 'e'] - attr = 'foobar' - self.io.write_dataset(self.f, DatasetBuilder('test_dataset', a, attributes={'test_attr': attr}, dtype='text')) + a = ["a", "bb", "ccc", "dddd", "e"] + attr = "foobar" + self.io.write_dataset( + self.f, + DatasetBuilder("test_dataset", a, attributes={"test_attr": attr}, dtype="text"), + ) self.io.close() - with HDF5IO(self.path, 'r') as io: + with HDF5IO(self.path, "r") as io: bldr = io.read_builder() - np.array_equal(bldr['test_dataset'].data[:], ['a', 'bb', 'ccc', 'dddd', 'e']) - np.array_equal(bldr['test_dataset'].attributes['test_attr'], attr) + np.array_equal(bldr["test_dataset"].data[:], ["a", "bb", "ccc", "dddd", "e"]) + np.array_equal(bldr["test_dataset"].attributes["test_attr"], attr) if H5PY_3: - self.assertEqual(str(bldr['test_dataset'].data), - '') + self.assertEqual( + str(bldr["test_dataset"].data), + '', + ) else: - self.assertEqual(str(bldr['test_dataset'].data), - '') + self.assertEqual( + str(bldr["test_dataset"].data), + '', + ) class TestRoundTrip(TestCase): - def setUp(self): self.manager = get_foo_buildmanager() self.path = get_temp_filepath() @@ -752,64 +905,70 @@ def tearDown(self): def test_roundtrip_basic(self): # Setup all the data we need - foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) - foobucket = FooBucket('bucket1', [foo1]) + foo1 = Foo("foo1", [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) + foobucket = FooBucket("bucket1", [foo1]) foofile = FooFile(buckets=[foobucket]) - with HDF5IO(self.path, manager=self.manager, mode='w') as io: + with HDF5IO(self.path, manager=self.manager, mode="w") as io: io.write(foofile) - with HDF5IO(self.path, manager=self.manager, mode='r') as io: + with HDF5IO(self.path, manager=self.manager, mode="r") as io: read_foofile = 
io.read() - self.assertListEqual(foofile.buckets['bucket1'].foos['foo1'].my_data, - read_foofile.buckets['bucket1'].foos['foo1'].my_data[:].tolist()) + self.assertListEqual( + foofile.buckets["bucket1"].foos["foo1"].my_data, + read_foofile.buckets["bucket1"].foos["foo1"].my_data[:].tolist(), + ) def test_roundtrip_empty_dataset(self): - foo1 = Foo('foo1', [], "I am foo1", 17, 3.14) - foobucket = FooBucket('bucket1', [foo1]) + foo1 = Foo("foo1", [], "I am foo1", 17, 3.14) + foobucket = FooBucket("bucket1", [foo1]) foofile = FooFile(buckets=[foobucket]) - with HDF5IO(self.path, manager=self.manager, mode='w') as io: + with HDF5IO(self.path, manager=self.manager, mode="w") as io: io.write(foofile) - with HDF5IO(self.path, manager=self.manager, mode='r') as io: + with HDF5IO(self.path, manager=self.manager, mode="r") as io: read_foofile = io.read() - self.assertListEqual([], read_foofile.buckets['bucket1'].foos['foo1'].my_data[:].tolist()) + self.assertListEqual( + [], + read_foofile.buckets["bucket1"].foos["foo1"].my_data[:].tolist(), + ) def test_roundtrip_empty_group(self): - foobucket = FooBucket('bucket1', []) + foobucket = FooBucket("bucket1", []) foofile = FooFile(buckets=[foobucket]) - with HDF5IO(self.path, manager=self.manager, mode='w') as io: + with HDF5IO(self.path, manager=self.manager, mode="w") as io: io.write(foofile) - with HDF5IO(self.path, manager=self.manager, mode='r') as io: + with HDF5IO(self.path, manager=self.manager, mode="r") as io: read_foofile = io.read() - self.assertDictEqual({}, read_foofile.buckets['bucket1'].foos) + self.assertDictEqual({}, read_foofile.buckets["bucket1"].foos) def test_roundtrip_pathlib_path(self): pathlib_path = Path(self.path) - foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) - foobucket = FooBucket('bucket1', [foo1]) + foo1 = Foo("foo1", [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) + foobucket = FooBucket("bucket1", [foo1]) foofile = FooFile([foobucket]) - with HDF5IO(pathlib_path, manager=self.manager, mode='w') as io: + with HDF5IO(pathlib_path, manager=self.manager, mode="w") as io: io.write(foofile) - with HDF5IO(pathlib_path, manager=self.manager, mode='r') as io: + with HDF5IO(pathlib_path, manager=self.manager, mode="r") as io: read_foofile = io.read() - self.assertListEqual(foofile.buckets['bucket1'].foos['foo1'].my_data, - read_foofile.buckets['bucket1'].foos['foo1'].my_data[:].tolist()) + self.assertListEqual( + foofile.buckets["bucket1"].foos["foo1"].my_data, + read_foofile.buckets["bucket1"].foos["foo1"].my_data[:].tolist(), + ) class TestHDF5IO(TestCase): - def setUp(self): self.manager = get_foo_buildmanager() self.path = get_temp_filepath() - foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) - foobucket = FooBucket('bucket1', [foo1]) + foo1 = Foo("foo1", [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) + foobucket = FooBucket("bucket1", [foo1]) self.foofile = FooFile(buckets=[foobucket]) self.file_obj = None @@ -825,7 +984,7 @@ def tearDown(self): os.remove(fn) def test_constructor(self): - with HDF5IO(self.path, manager=self.manager, mode='w') as io: + with HDF5IO(self.path, manager=self.manager, mode="w") as io: self.assertEqual(io.manager, self.manager) self.assertEqual(io.source, self.path) @@ -834,12 +993,16 @@ def test_delete_with_incomplete_construction_missing_file(self): Here we test what happens when `close` is called before `HDF5IO.__init__` has been completed. In this case, self.__file is missing. 
""" + class MyHDF5IO(HDF5IO): def __init__(self): self.__open_links = [] raise ValueError("interrupt before HDF5IO.__file is initialized") - with self.assertRaisesWith(exc_type=ValueError, exc_msg="interrupt before HDF5IO.__file is initialized"): + with self.assertRaisesWith( + exc_type=ValueError, + exc_msg="interrupt before HDF5IO.__file is initialized", + ): with MyHDF5IO() as _: pass @@ -848,25 +1011,31 @@ def test_delete_with_incomplete_construction_missing_open_files(self): Here we test what happens when `close` is called before `HDF5IO.__init__` has been completed. In this case, self.__open_files is missing. """ + class MyHDF5IO(HDF5IO): def __init__(self): self.__file = None raise ValueError("interrupt before HDF5IO.__open_files is initialized") - with self.assertRaisesWith(exc_type=ValueError, exc_msg="interrupt before HDF5IO.__open_files is initialized"): + with self.assertRaisesWith( + exc_type=ValueError, + exc_msg="interrupt before HDF5IO.__open_files is initialized", + ): with MyHDF5IO() as _: pass def test_set_file_mismatch(self): - self.file_obj = File(get_temp_filepath(), 'w') - err_msg = ("You argued '%s' as this object's path, but supplied a file with filename: %s" - % (self.path, self.file_obj.filename)) + self.file_obj = File(get_temp_filepath(), "w") + err_msg = "You argued '%s' as this object's path, but supplied a file with filename: %s" % ( + self.path, + self.file_obj.filename, + ) with self.assertRaisesWith(ValueError, err_msg): - HDF5IO(self.path, manager=self.manager, mode='w', file=self.file_obj) + HDF5IO(self.path, manager=self.manager, mode="w", file=self.file_obj) def test_pathlib_path(self): pathlib_path = Path(self.path) - with HDF5IO(pathlib_path, mode='w') as io: + with HDF5IO(pathlib_path, mode="w") as io: self.assertEqual(io.source, self.path) def test_path_or_file(self): @@ -875,7 +1044,6 @@ def test_path_or_file(self): class TestCacheSpec(TestCase): - def setUp(self): self.manager = get_foo_buildmanager() self.path = get_temp_filepath() @@ -885,12 +1053,12 @@ def tearDown(self): os.remove(self.path) def test_cache_spec(self): - foo1 = Foo('foo1', [0, 1, 2, 3, 4], "I am foo1", 17, 3.14) - foo2 = Foo('foo2', [5, 6, 7, 8, 9], "I am foo2", 34, 6.28) - foobucket = FooBucket('bucket1', [foo1, foo2]) + foo1 = Foo("foo1", [0, 1, 2, 3, 4], "I am foo1", 17, 3.14) + foo2 = Foo("foo2", [5, 6, 7, 8, 9], "I am foo2", 34, 6.28) + foobucket = FooBucket("bucket1", [foo1, foo2]) foofile = FooFile(buckets=[foobucket]) - with HDF5IO(self.path, manager=self.manager, mode='w') as io: + with HDF5IO(self.path, manager=self.manager, mode="w") as io: io.write(foofile) ns_catalog = NamespaceCatalog() @@ -902,7 +1070,6 @@ def test_cache_spec(self): class TestNoCacheSpec(TestCase): - def setUp(self): self.manager = get_foo_buildmanager() self.path = get_temp_filepath() @@ -913,25 +1080,24 @@ def tearDown(self): def test_no_cache_spec(self): # Setup all the data we need - foo1 = Foo('foo1', [0, 1, 2, 3, 4], "I am foo1", 17, 3.14) - foo2 = Foo('foo2', [5, 6, 7, 8, 9], "I am foo2", 34, 6.28) - foobucket = FooBucket('bucket1', [foo1, foo2]) + foo1 = Foo("foo1", [0, 1, 2, 3, 4], "I am foo1", 17, 3.14) + foo2 = Foo("foo2", [5, 6, 7, 8, 9], "I am foo2", 34, 6.28) + foobucket = FooBucket("bucket1", [foo1, foo2]) foofile = FooFile(buckets=[foobucket]) - with HDF5IO(self.path, manager=self.manager, mode='w') as io: + with HDF5IO(self.path, manager=self.manager, mode="w") as io: io.write(foofile, cache_spec=False) - with File(self.path, 'r') as f: - self.assertNotIn('specifications', f) + 
with File(self.path, "r") as f: + self.assertNotIn("specifications", f) class TestMultiWrite(TestCase): - def setUp(self): self.path = get_temp_filepath() - foo1 = Foo('foo1', [0, 1, 2, 3, 4], "I am foo1", 17, 3.14) - foo2 = Foo('foo2', [5, 6, 7, 8, 9], "I am foo2", 34, 6.28) - foobucket = FooBucket('bucket1', [foo1, foo2]) + foo1 = Foo("foo1", [0, 1, 2, 3, 4], "I am foo1", 17, 3.14) + foo2 = Foo("foo2", [5, 6, 7, 8, 9], "I am foo2", 34, 6.28) + foobucket = FooBucket("bucket1", [foo1, foo2]) self.foofile = FooFile(buckets=[foobucket]) def tearDown(self): @@ -940,116 +1106,116 @@ def tearDown(self): def test_double_write_new_manager(self): """Test writing to a container in write mode twice using a new manager without changing the container.""" - with HDF5IO(self.path, manager=get_foo_buildmanager(), mode='w') as io: + with HDF5IO(self.path, manager=get_foo_buildmanager(), mode="w") as io: io.write(self.foofile) - with HDF5IO(self.path, manager=get_foo_buildmanager(), mode='w') as io: + with HDF5IO(self.path, manager=get_foo_buildmanager(), mode="w") as io: io.write(self.foofile) # check that new bucket was written - with HDF5IO(self.path, manager=get_foo_buildmanager(), mode='r') as io: + with HDF5IO(self.path, manager=get_foo_buildmanager(), mode="r") as io: read_foofile = io.read() self.assertContainerEqual(read_foofile, self.foofile) def test_double_write_same_manager(self): """Test writing to a container in write mode twice using the same manager without changing the container.""" manager = get_foo_buildmanager() - with HDF5IO(self.path, manager=manager, mode='w') as io: + with HDF5IO(self.path, manager=manager, mode="w") as io: io.write(self.foofile) - with HDF5IO(self.path, manager=manager, mode='w') as io: + with HDF5IO(self.path, manager=manager, mode="w") as io: io.write(self.foofile) # check that new bucket was written - with HDF5IO(self.path, manager=get_foo_buildmanager(), mode='r') as io: + with HDF5IO(self.path, manager=get_foo_buildmanager(), mode="r") as io: read_foofile = io.read() self.assertContainerEqual(read_foofile, self.foofile) - @unittest.skip('Functionality not yet supported') + @unittest.skip("Functionality not yet supported") def test_double_append_new_manager(self): """Test writing to a container in append mode twice using a new manager without changing the container.""" - with HDF5IO(self.path, manager=get_foo_buildmanager(), mode='a') as io: + with HDF5IO(self.path, manager=get_foo_buildmanager(), mode="a") as io: io.write(self.foofile) - with HDF5IO(self.path, manager=get_foo_buildmanager(), mode='a') as io: + with HDF5IO(self.path, manager=get_foo_buildmanager(), mode="a") as io: io.write(self.foofile) # check that new bucket was written - with HDF5IO(self.path, manager=get_foo_buildmanager(), mode='r') as io: + with HDF5IO(self.path, manager=get_foo_buildmanager(), mode="r") as io: read_foofile = io.read() self.assertContainerEqual(read_foofile, self.foofile) - @unittest.skip('Functionality not yet supported') + @unittest.skip("Functionality not yet supported") def test_double_append_same_manager(self): """Test writing to a container in append mode twice using the same manager without changing the container.""" manager = get_foo_buildmanager() - with HDF5IO(self.path, manager=manager, mode='a') as io: + with HDF5IO(self.path, manager=manager, mode="a") as io: io.write(self.foofile) - with HDF5IO(self.path, manager=manager, mode='a') as io: + with HDF5IO(self.path, manager=manager, mode="a") as io: io.write(self.foofile) # check that new bucket was written - 
with HDF5IO(self.path, manager=get_foo_buildmanager(), mode='r') as io: + with HDF5IO(self.path, manager=get_foo_buildmanager(), mode="r") as io: read_foofile = io.read() self.assertContainerEqual(read_foofile, self.foofile) def test_write_add_write(self): """Test writing a container, adding to the in-memory container, then overwriting the same file.""" manager = get_foo_buildmanager() - with HDF5IO(self.path, manager=manager, mode='w') as io: + with HDF5IO(self.path, manager=manager, mode="w") as io: io.write(self.foofile) # append new container to in-memory container - foo3 = Foo('foo3', [10, 20], "I am foo3", 2, 0.1) - new_bucket1 = FooBucket('new_bucket1', [foo3]) + foo3 = Foo("foo3", [10, 20], "I am foo3", 2, 0.1) + new_bucket1 = FooBucket("new_bucket1", [foo3]) self.foofile.add_bucket(new_bucket1) # write to same file with same manager, overwriting existing file - with HDF5IO(self.path, manager=manager, mode='w') as io: + with HDF5IO(self.path, manager=manager, mode="w") as io: io.write(self.foofile) # check that new bucket was written - with HDF5IO(self.path, manager=get_foo_buildmanager(), mode='r') as io: + with HDF5IO(self.path, manager=get_foo_buildmanager(), mode="r") as io: read_foofile = io.read() self.assertEqual(len(read_foofile.buckets), 2) - self.assertContainerEqual(read_foofile.buckets['new_bucket1'], new_bucket1) + self.assertContainerEqual(read_foofile.buckets["new_bucket1"], new_bucket1) def test_write_add_append_bucket(self): """Test appending a container to a file.""" manager = get_foo_buildmanager() - with HDF5IO(self.path, manager=manager, mode='w') as io: + with HDF5IO(self.path, manager=manager, mode="w") as io: io.write(self.foofile) - foo3 = Foo('foo3', [10, 20], "I am foo3", 2, 0.1) - new_bucket1 = FooBucket('new_bucket1', [foo3]) + foo3 = Foo("foo3", [10, 20], "I am foo3", 2, 0.1) + new_bucket1 = FooBucket("new_bucket1", [foo3]) # append to same file with same manager, overwriting existing file - with HDF5IO(self.path, manager=manager, mode='a') as io: + with HDF5IO(self.path, manager=manager, mode="a") as io: read_foofile = io.read() # append to read container and call write read_foofile.add_bucket(new_bucket1) io.write(read_foofile) # check that new bucket was written - with HDF5IO(self.path, manager=get_foo_buildmanager(), mode='r') as io: + with HDF5IO(self.path, manager=get_foo_buildmanager(), mode="r") as io: read_foofile = io.read() self.assertEqual(len(read_foofile.buckets), 2) - self.assertContainerEqual(read_foofile.buckets['new_bucket1'], new_bucket1) + self.assertContainerEqual(read_foofile.buckets["new_bucket1"], new_bucket1) def test_write_add_append_double_write(self): """Test using the same IO object to append a container to a file twice.""" manager = get_foo_buildmanager() - with HDF5IO(self.path, manager=manager, mode='w') as io: + with HDF5IO(self.path, manager=manager, mode="w") as io: io.write(self.foofile) - foo3 = Foo('foo3', [10, 20], "I am foo3", 2, 0.1) - new_bucket1 = FooBucket('new_bucket1', [foo3]) - foo4 = Foo('foo4', [10, 20], "I am foo4", 2, 0.1) - new_bucket2 = FooBucket('new_bucket2', [foo4]) + foo3 = Foo("foo3", [10, 20], "I am foo3", 2, 0.1) + new_bucket1 = FooBucket("new_bucket1", [foo3]) + foo4 = Foo("foo4", [10, 20], "I am foo4", 2, 0.1) + new_bucket2 = FooBucket("new_bucket2", [foo4]) # append to same file with same manager, overwriting existing file - with HDF5IO(self.path, manager=manager, mode='a') as io: + with HDF5IO(self.path, manager=manager, mode="a") as io: read_foofile = io.read() # append to read container 
and call write read_foofile.add_bucket(new_bucket1) @@ -1060,11 +1226,11 @@ def test_write_add_append_double_write(self): io.write(read_foofile) # check that both new buckets were written - with HDF5IO(self.path, manager=get_foo_buildmanager(), mode='r') as io: + with HDF5IO(self.path, manager=get_foo_buildmanager(), mode="r") as io: read_foofile = io.read() self.assertEqual(len(read_foofile.buckets), 3) - self.assertContainerEqual(read_foofile.buckets['new_bucket1'], new_bucket1) - self.assertContainerEqual(read_foofile.buckets['new_bucket2'], new_bucket2) + self.assertContainerEqual(read_foofile.buckets["new_bucket1"], new_bucket1) + self.assertContainerEqual(read_foofile.buckets["new_bucket2"], new_bucket2) class HDF5IOMultiFileTest(TestCase): @@ -1077,7 +1243,7 @@ def setUp(self): # On Windows h5py cannot truncate an open file in write mode. # The temp file will be closed before h5py truncates it # and will be removed during the tearDown step. - self.io = [HDF5IO(i, mode='a', manager=get_foo_buildmanager()) for i in self.paths] + self.io = [HDF5IO(i, mode="a", manager=get_foo_buildmanager()) for i in self.paths] self.f = [i._file for i in self.io] def tearDown(self): @@ -1096,8 +1262,8 @@ def tearDown(self): def test_copy_file_with_external_links(self): # Create the first file - foo1 = Foo('foo1', [0, 1, 2, 3, 4], "I am foo1", 17, 3.14) - bucket1 = FooBucket('bucket1', [foo1]) + foo1 = Foo("foo1", [0, 1, 2, 3, 4], "I am foo1", 17, 3.14) + bucket1 = FooBucket("bucket1", [foo1]) foofile1 = FooFile(buckets=[bucket1]) # Write the first file @@ -1105,8 +1271,14 @@ def test_copy_file_with_external_links(self): # Create the second file read_foofile1 = self.io[0].read() - foo2 = Foo('foo2', read_foofile1.buckets['bucket1'].foos['foo1'].my_data, "I am foo2", 34, 6.28) - bucket2 = FooBucket('bucket2', [foo2]) + foo2 = Foo( + "foo2", + read_foofile1.buckets["bucket1"].foos["foo1"].my_data, + "I am foo2", + 34, + 6.28, + ) + bucket2 = FooBucket("bucket2", [foo2]) foofile2 = FooFile(buckets=[bucket2]) # Write the second file self.io[1].write(foofile2) @@ -1117,26 +1289,36 @@ def test_copy_file_with_external_links(self): self.io[2].close() with self.assertWarns(DeprecationWarning): - HDF5IO.copy_file(source_filename=self.paths[1], - dest_filename=self.paths[2], - expand_external=True, - expand_soft=False, - expand_refs=False) + HDF5IO.copy_file( + source_filename=self.paths[1], + dest_filename=self.paths[2], + expand_external=True, + expand_soft=False, + expand_refs=False, + ) # Test that everything is working as expected # Confirm that our original data file is correct - f1 = File(self.paths[0], 'r') - self.assertIsInstance(f1.get('/buckets/bucket1/foo_holder/foo1/my_data', getlink=True), HardLink) + f1 = File(self.paths[0], "r") + self.assertIsInstance( + f1.get("/buckets/bucket1/foo_holder/foo1/my_data", getlink=True), + HardLink, + ) # Confirm that we successfully created and External Link in our second file - f2 = File(self.paths[1], 'r') - self.assertIsInstance(f2.get('/buckets/bucket2/foo_holder/foo2/my_data', getlink=True), ExternalLink) + f2 = File(self.paths[1], "r") + self.assertIsInstance( + f2.get("/buckets/bucket2/foo_holder/foo2/my_data", getlink=True), + ExternalLink, + ) # Confirm that we successfully resolved the External Link when we copied our second file - f3 = File(self.paths[2], 'r') - self.assertIsInstance(f3.get('/buckets/bucket2/foo_holder/foo2/my_data', getlink=True), HardLink) + f3 = File(self.paths[2], "r") + self.assertIsInstance( + 
f3.get("/buckets/bucket2/foo_holder/foo2/my_data", getlink=True), + HardLink, + ) class TestCloseLinks(TestCase): - def setUp(self): self.path1 = get_temp_filepath() self.path2 = get_temp_filepath() @@ -1148,59 +1330,57 @@ def tearDown(self): os.remove(self.path2) def test_close_linked_files_auto(self): - """Test closing a file with close_links=True (default). - """ + """Test closing a file with close_links=True (default).""" # Create the first file - foo1 = Foo('foo1', [0, 1, 2, 3, 4], "I am foo1", 17, 3.14) - bucket1 = FooBucket('bucket1', [foo1]) + foo1 = Foo("foo1", [0, 1, 2, 3, 4], "I am foo1", 17, 3.14) + bucket1 = FooBucket("bucket1", [foo1]) foofile1 = FooFile(buckets=[bucket1]) # Write the first file - with HDF5IO(self.path1, mode='w', manager=get_foo_buildmanager()) as io: + with HDF5IO(self.path1, mode="w", manager=get_foo_buildmanager()) as io: io.write(foofile1) # Create the second file manager = get_foo_buildmanager() # use the same manager for read and write so that links work - with HDF5IO(self.path1, mode='r', manager=manager) as read_io: + with HDF5IO(self.path1, mode="r", manager=manager) as read_io: read_foofile1 = read_io.read() - foofile2 = FooFile(foo_link=read_foofile1.buckets['bucket1'].foos['foo1']) # cross-file link + foofile2 = FooFile(foo_link=read_foofile1.buckets["bucket1"].foos["foo1"]) # cross-file link # Write the second file - with HDF5IO(self.path2, mode='w', manager=manager) as write_io: + with HDF5IO(self.path2, mode="w", manager=manager) as write_io: write_io.write(foofile2) - with HDF5IO(self.path2, mode='a', manager=get_foo_buildmanager()) as new_io1: + with HDF5IO(self.path2, mode="a", manager=get_foo_buildmanager()) as new_io1: read_foofile2 = new_io1.read() # keep reference to container in memory self.assertFalse(read_foofile2.foo_link.my_data) # should be able to reopen both files - with HDF5IO(self.path1, mode='a', manager=get_foo_buildmanager()) as new_io3: + with HDF5IO(self.path1, mode="a", manager=get_foo_buildmanager()) as new_io3: new_io3.read() def test_close_linked_files_explicit(self): - """Test closing a file with close_links=False and calling close_linked_files(). 
- """ + """Test closing a file with close_links=False and calling close_linked_files().""" # Create the first file - foo1 = Foo('foo1', [0, 1, 2, 3, 4], "I am foo1", 17, 3.14) - bucket1 = FooBucket('bucket1', [foo1]) + foo1 = Foo("foo1", [0, 1, 2, 3, 4], "I am foo1", 17, 3.14) + bucket1 = FooBucket("bucket1", [foo1]) foofile1 = FooFile(buckets=[bucket1]) # Write the first file - with HDF5IO(self.path1, mode='w', manager=get_foo_buildmanager()) as io: + with HDF5IO(self.path1, mode="w", manager=get_foo_buildmanager()) as io: io.write(foofile1) # Create the second file manager = get_foo_buildmanager() # use the same manager for read and write so that links work - with HDF5IO(self.path1, mode='r', manager=manager) as read_io: + with HDF5IO(self.path1, mode="r", manager=manager) as read_io: read_foofile1 = read_io.read() - foofile2 = FooFile(foo_link=read_foofile1.buckets['bucket1'].foos['foo1']) # cross-file link + foofile2 = FooFile(foo_link=read_foofile1.buckets["bucket1"].foos["foo1"]) # cross-file link # Write the second file - with HDF5IO(self.path2, mode='w', manager=manager) as write_io: + with HDF5IO(self.path2, mode="w", manager=manager) as write_io: write_io.write(foofile2) - new_io1 = HDF5IO(self.path2, mode='a', manager=get_foo_buildmanager()) + new_io1 = HDF5IO(self.path2, mode="a", manager=get_foo_buildmanager()) read_foofile2 = new_io1.read() # keep reference to container in memory new_io1.close(close_links=False) # do not close the links @@ -1209,32 +1389,31 @@ def test_close_linked_files_explicit(self): self.assertFalse(read_foofile2.foo_link.my_data) # should be able to reopen both files - with HDF5IO(self.path1, mode='a', manager=get_foo_buildmanager()) as new_io3: + with HDF5IO(self.path1, mode="a", manager=get_foo_buildmanager()) as new_io3: new_io3.read() def test_close_links_manually_and_close(self): - """Test closing a file with close_links=False, manually closing open links, and calling close_linked_files(). 
- """ + """Test closing a file with close_links=False, manually closing open links, and calling close_linked_files().""" # Create the first file - foo1 = Foo('foo1', [0, 1, 2, 3, 4], "I am foo1", 17, 3.14) - bucket1 = FooBucket('bucket1', [foo1]) + foo1 = Foo("foo1", [0, 1, 2, 3, 4], "I am foo1", 17, 3.14) + bucket1 = FooBucket("bucket1", [foo1]) foofile1 = FooFile(buckets=[bucket1]) # Write the first file - with HDF5IO(self.path1, mode='w', manager=get_foo_buildmanager()) as io: + with HDF5IO(self.path1, mode="w", manager=get_foo_buildmanager()) as io: io.write(foofile1) # Create the second file manager = get_foo_buildmanager() # use the same manager for read and write so that links work - with HDF5IO(self.path1, mode='r', manager=manager) as read_io: + with HDF5IO(self.path1, mode="r", manager=manager) as read_io: read_foofile1 = read_io.read() - foofile2 = FooFile(foo_link=read_foofile1.buckets['bucket1'].foos['foo1']) # cross-file link + foofile2 = FooFile(foo_link=read_foofile1.buckets["bucket1"].foos["foo1"]) # cross-file link # Write the second file - with HDF5IO(self.path2, mode='w', manager=manager) as write_io: + with HDF5IO(self.path2, mode="w", manager=manager) as write_io: write_io.write(foofile2) - new_io1 = HDF5IO(self.path2, mode='a', manager=get_foo_buildmanager()) + new_io1 = HDF5IO(self.path2, mode="a", manager=get_foo_buildmanager()) read_foofile2 = new_io1.read() # keep reference to container in memory new_io1.close(close_links=False) # do not close the links @@ -1244,31 +1423,30 @@ def test_close_links_manually_and_close(self): new_io1.close_linked_files() # make sure this does not fail because the linked-to file is already closed def test_close_linked_files_not_disruptive(self): - """Test closing a file with close_links=True (default) does not interfere with other open file handles. 
- """ + """Test closing a file with close_links=True (default) does not interfere with other open file handles.""" # Create the first file - foo1 = Foo('foo1', [0, 1, 2, 3, 4], "I am foo1", 17, 3.14) - bucket1 = FooBucket('bucket1', [foo1]) + foo1 = Foo("foo1", [0, 1, 2, 3, 4], "I am foo1", 17, 3.14) + bucket1 = FooBucket("bucket1", [foo1]) foofile1 = FooFile(buckets=[bucket1]) # Write the first file - with HDF5IO(self.path1, mode='w', manager=get_foo_buildmanager()) as io: + with HDF5IO(self.path1, mode="w", manager=get_foo_buildmanager()) as io: io.write(foofile1) # Create the second file manager = get_foo_buildmanager() # use the same manager for read and write so that links work - with HDF5IO(self.path1, mode='r', manager=manager) as read_io: + with HDF5IO(self.path1, mode="r", manager=manager) as read_io: read_foofile1 = read_io.read() - foofile2 = FooFile(foo_link=read_foofile1.buckets['bucket1'].foos['foo1']) # cross-file link + foofile2 = FooFile(foo_link=read_foofile1.buckets["bucket1"].foos["foo1"]) # cross-file link # Write the second file - with HDF5IO(self.path2, mode='w', manager=manager) as write_io: + with HDF5IO(self.path2, mode="w", manager=manager) as write_io: write_io.write(foofile2) - read_io = HDF5IO(self.path1, mode='r', manager=manager) + read_io = HDF5IO(self.path1, mode="r", manager=manager) read_foofile1 = read_io.read() - with HDF5IO(self.path2, mode='r', manager=get_foo_buildmanager()) as new_io1: + with HDF5IO(self.path2, mode="r", manager=get_foo_buildmanager()) as new_io1: new_io1.read() # keep reference to container in memory self.assertTrue(read_io) # make sure read_io is not closed @@ -1276,23 +1454,27 @@ def test_close_linked_files_not_disruptive(self): class HDF5IOInitNoFileTest(TestCase): - """ Test if file does not exist, init with mode (r, r+) throws error, all others succeed """ + """Test if file does not exist, init with mode (r, r+) throws error, all others succeed""" def test_init_no_file_r(self): self.path = "test_init_nofile_r.h5" - with self.assertRaisesWith(UnsupportedOperation, - "Unable to open file %s in 'r' mode. File does not exist." % self.path): - HDF5IO(self.path, mode='r') + with self.assertRaisesWith( + UnsupportedOperation, + "Unable to open file %s in 'r' mode. File does not exist." % self.path, + ): + HDF5IO(self.path, mode="r") def test_init_no_file_rplus(self): self.path = "test_init_nofile_rplus.h5" - with self.assertRaisesWith(UnsupportedOperation, - "Unable to open file %s in 'r+' mode. File does not exist." % self.path): - HDF5IO(self.path, mode='r+') + with self.assertRaisesWith( + UnsupportedOperation, + "Unable to open file %s in 'r+' mode. File does not exist." 
% self.path, + ): + HDF5IO(self.path, mode="r+") def test_init_no_file_ok(self): # test that no errors are thrown - modes = ('w', 'w-', 'x', 'a') + modes = ("w", "w-", "x", "a") for m in modes: self.path = "test_init_nofile.h5" with HDF5IO(self.path, mode=m): @@ -1302,11 +1484,11 @@ def test_init_no_file_ok(self): class HDF5IOInitFileExistsTest(TestCase): - """ Test if file exists, init with mode w-/x throws error, all others succeed """ + """Test if file exists, init with mode w-/x throws error, all others succeed""" def setUp(self): self.path = get_temp_filepath() - temp_io = HDF5IO(self.path, mode='w') + temp_io = HDF5IO(self.path, mode="w") temp_io.close() self.io = None @@ -1318,29 +1500,33 @@ def tearDown(self): os.remove(self.path) def test_init_wminus_file_exists(self): - with self.assertRaisesWith(UnsupportedOperation, - "Unable to open file %s in 'w-' mode. File already exists." % self.path): - self.io = HDF5IO(self.path, mode='w-') + with self.assertRaisesWith( + UnsupportedOperation, + "Unable to open file %s in 'w-' mode. File already exists." % self.path, + ): + self.io = HDF5IO(self.path, mode="w-") def test_init_x_file_exists(self): - with self.assertRaisesWith(UnsupportedOperation, - "Unable to open file %s in 'x' mode. File already exists." % self.path): - self.io = HDF5IO(self.path, mode='x') + with self.assertRaisesWith( + UnsupportedOperation, + "Unable to open file %s in 'x' mode. File already exists." % self.path, + ): + self.io = HDF5IO(self.path, mode="x") def test_init_file_exists_ok(self): # test that no errors are thrown - modes = ('r', 'r+', 'w', 'a') + modes = ("r", "r+", "w", "a") for m in modes: with HDF5IO(self.path, mode=m): pass class HDF5IOReadNoDataTest(TestCase): - """ Test if file exists and there is no data, read with mode (r, r+, a) throws error """ + """Test if file exists and there is no data, read with mode (r, r+, a) throws error""" def setUp(self): self.path = get_temp_filepath() - temp_io = HDF5IO(self.path, mode='w') + temp_io = HDF5IO(self.path, mode="w") temp_io.close() self.io = None @@ -1353,36 +1539,42 @@ def tearDown(self): os.remove(self.path) def test_read_no_data_r(self): - self.io = HDF5IO(self.path, mode='r') - with self.assertRaisesWith(UnsupportedOperation, - "Cannot read data from file %s in mode 'r'. There are no values." % self.path): + self.io = HDF5IO(self.path, mode="r") + with self.assertRaisesWith( + UnsupportedOperation, + "Cannot read data from file %s in mode 'r'. There are no values." % self.path, + ): self.io.read() def test_read_no_data_rplus(self): - self.io = HDF5IO(self.path, mode='r+') - with self.assertRaisesWith(UnsupportedOperation, - "Cannot read data from file %s in mode 'r+'. There are no values." % self.path): + self.io = HDF5IO(self.path, mode="r+") + with self.assertRaisesWith( + UnsupportedOperation, + "Cannot read data from file %s in mode 'r+'. There are no values." % self.path, + ): self.io.read() def test_read_no_data_a(self): - self.io = HDF5IO(self.path, mode='a') - with self.assertRaisesWith(UnsupportedOperation, - "Cannot read data from file %s in mode 'a'. There are no values." % self.path): + self.io = HDF5IO(self.path, mode="a") + with self.assertRaisesWith( + UnsupportedOperation, + "Cannot read data from file %s in mode 'a'. There are no values." 
% self.path, + ): self.io.read() class HDF5IOReadData(TestCase): - """ Test if file exists and there is no data, read in mode (r, r+, a) is ok + """Test if file exists and there is no data, read in mode (r, r+, a) is ok and read in mode w throws error """ def setUp(self): self.path = get_temp_filepath() - foo1 = Foo('foo1', [0, 1, 2, 3, 4], "I am foo1", 17, 3.14) - bucket1 = FooBucket('bucket1', [foo1]) + foo1 = Foo("foo1", [0, 1, 2, 3, 4], "I am foo1", 17, 3.14) + bucket1 = FooBucket("bucket1", [foo1]) self.foofile1 = FooFile(buckets=[bucket1]) - with HDF5IO(self.path, manager=get_foo_buildmanager(), mode='w') as temp_io: + with HDF5IO(self.path, manager=get_foo_buildmanager(), mode="w") as temp_io: temp_io.write(self.foofile1) self.io = None @@ -1394,27 +1586,30 @@ def tearDown(self): os.remove(self.path) def test_read_file_ok(self): - modes = ('r', 'r+', 'a') + modes = ("r", "r+", "a") for m in modes: with HDF5IO(self.path, manager=get_foo_buildmanager(), mode=m) as io: io.read() def test_read_file_w(self): - with HDF5IO(self.path, manager=get_foo_buildmanager(), mode='w') as io: - with self.assertRaisesWith(UnsupportedOperation, - "Cannot read from file %s in mode 'w'. Please use mode 'r', 'r+', or 'a'." - % self.path): + with HDF5IO(self.path, manager=get_foo_buildmanager(), mode="w") as io: + with self.assertRaisesWith( + UnsupportedOperation, + "Cannot read from file %s in mode 'w'. Please use mode 'r', 'r+', or 'a'." % self.path, + ): read_foofile1 = io.read() - self.assertListEqual(self.foofile1.buckets['bucket1'].foos['foo1'].my_data, - read_foofile1.buckets['bucket1'].foos['foo1'].my_data[:].tolist()) + self.assertListEqual( + self.foofile1.buckets["bucket1"].foos["foo1"].my_data, + read_foofile1.buckets["bucket1"].foos["foo1"].my_data[:].tolist(), + ) class HDF5IOReadBuilderClosed(TestCase): - """Test if file exists but is closed, then read_builder raises an error. 
""" + """Test if file exists but is closed, then read_builder raises an error.""" def setUp(self): self.path = get_temp_filepath() - temp_io = HDF5IO(self.path, mode='w') + temp_io = HDF5IO(self.path, mode="w") temp_io.close() self.io = None @@ -1427,7 +1622,7 @@ def tearDown(self): os.remove(self.path) def test_read_closed(self): - self.io = HDF5IO(self.path, mode='r') + self.io = HDF5IO(self.path, mode="r") self.io.close() msg = "Cannot read data from closed HDF5 file '%s'" % self.path with self.assertRaisesWith(UnsupportedOperation, msg): @@ -1435,55 +1630,57 @@ def test_read_closed(self): class HDF5IOWriteNoFile(TestCase): - """ Test if file does not exist, write in mode (w, w-, x, a) is ok """ + """Test if file does not exist, write in mode (w, w-, x, a) is ok""" def setUp(self): - foo1 = Foo('foo1', [0, 1, 2, 3, 4], "I am foo1", 17, 3.14) - bucket1 = FooBucket('bucket1', [foo1]) + foo1 = Foo("foo1", [0, 1, 2, 3, 4], "I am foo1", 17, 3.14) + bucket1 = FooBucket("bucket1", [foo1]) self.foofile1 = FooFile(buckets=[bucket1]) - self.path = 'test_write_nofile.h5' + self.path = "test_write_nofile.h5" def tearDown(self): if os.path.exists(self.path): os.remove(self.path) def test_write_no_file_w_ok(self): - self.__write_file('w') + self.__write_file("w") def test_write_no_file_wminus_ok(self): - self.__write_file('w-') + self.__write_file("w-") def test_write_no_file_x_ok(self): - self.__write_file('x') + self.__write_file("x") def test_write_no_file_a_ok(self): - self.__write_file('a') + self.__write_file("a") def __write_file(self, mode): with HDF5IO(self.path, manager=get_foo_buildmanager(), mode=mode) as io: io.write(self.foofile1) - with HDF5IO(self.path, manager=get_foo_buildmanager(), mode='r') as io: + with HDF5IO(self.path, manager=get_foo_buildmanager(), mode="r") as io: read_foofile = io.read() - self.assertListEqual(self.foofile1.buckets['bucket1'].foos['foo1'].my_data, - read_foofile.buckets['bucket1'].foos['foo1'].my_data[:].tolist()) + self.assertListEqual( + self.foofile1.buckets["bucket1"].foos["foo1"].my_data, + read_foofile.buckets["bucket1"].foos["foo1"].my_data[:].tolist(), + ) class HDF5IOWriteFileExists(TestCase): - """ Test if file exists, write in mode (r+, w, a) is ok and write in mode r throws error """ + """Test if file exists, write in mode (r+, w, a) is ok and write in mode r throws error""" def setUp(self): self.path = get_temp_filepath() - foo1 = Foo('foo1', [0, 1, 2, 3, 4], "I am foo1", 17, 3.14) - bucket1 = FooBucket('bucket1', [foo1]) + foo1 = Foo("foo1", [0, 1, 2, 3, 4], "I am foo1", 17, 3.14) + bucket1 = FooBucket("bucket1", [foo1]) self.foofile1 = FooFile(buckets=[bucket1]) - foo2 = Foo('foo2', [0, 1, 2, 3, 4], "I am foo2", 17, 3.14) - bucket2 = FooBucket('bucket2', [foo2]) + foo2 = Foo("foo2", [0, 1, 2, 3, 4], "I am foo2", 17, 3.14) + bucket2 = FooBucket("bucket2", [foo2]) self.foofile2 = FooFile(buckets=[bucket2]) - with HDF5IO(self.path, manager=get_foo_buildmanager(), mode='w') as io: + with HDF5IO(self.path, manager=get_foo_buildmanager(), mode="w") as io: io.write(self.foofile1) self.io = None @@ -1495,7 +1692,7 @@ def tearDown(self): os.remove(self.path) def test_write_rplus(self): - with HDF5IO(self.path, manager=get_foo_buildmanager(), mode='r+') as io: + with HDF5IO(self.path, manager=get_foo_buildmanager(), mode="r+") as io: # even though foofile1 and foofile2 have different names, writing a # root object into a file that already has a root object, in r+ mode # should throw an error @@ -1503,7 +1700,7 @@ def test_write_rplus(self): 
io.write(self.foofile2) def test_write_a(self): - with HDF5IO(self.path, manager=get_foo_buildmanager(), mode='a') as io: + with HDF5IO(self.path, manager=get_foo_buildmanager(), mode="a") as io: # even though foofile1 and foofile2 have different names, writing a # root object into a file that already has a root object, in a mode # should throw an error @@ -1512,30 +1709,32 @@ def test_write_a(self): def test_write_w(self): # mode 'w' should overwrite contents of file - with HDF5IO(self.path, manager=get_foo_buildmanager(), mode='w') as io: + with HDF5IO(self.path, manager=get_foo_buildmanager(), mode="w") as io: io.write(self.foofile2) - with HDF5IO(self.path, manager=get_foo_buildmanager(), mode='r') as io: + with HDF5IO(self.path, manager=get_foo_buildmanager(), mode="r") as io: read_foofile = io.read() - self.assertListEqual(self.foofile2.buckets['bucket2'].foos['foo2'].my_data, - read_foofile.buckets['bucket2'].foos['foo2'].my_data[:].tolist()) + self.assertListEqual( + self.foofile2.buckets["bucket2"].foos["foo2"].my_data, + read_foofile.buckets["bucket2"].foos["foo2"].my_data[:].tolist(), + ) def test_write_r(self): - with HDF5IO(self.path, manager=get_foo_buildmanager(), mode='r') as io: - with self.assertRaisesWith(UnsupportedOperation, - ("Cannot write to file %s in mode 'r'. " - "Please use mode 'r+', 'w', 'w-', 'x', or 'a'") % self.path): + with HDF5IO(self.path, manager=get_foo_buildmanager(), mode="r") as io: + with self.assertRaisesWith( + UnsupportedOperation, + "Cannot write to file %s in mode 'r'. Please use mode 'r+', 'w', 'w-', 'x', or 'a'" % self.path, + ): io.write(self.foofile2) class TestWritten(TestCase): - def setUp(self): self.manager = get_foo_buildmanager() self.path = get_temp_filepath() - foo1 = Foo('foo1', [0, 1, 2, 3, 4], "I am foo1", 17, 3.14) - foo2 = Foo('foo2', [5, 6, 7, 8, 9], "I am foo2", 34, 6.28) - foobucket = FooBucket('bucket1', [foo1, foo2]) + foo1 = Foo("foo1", [0, 1, 2, 3, 4], "I am foo1", 17, 3.14) + foo2 = Foo("foo2", [5, 6, 7, 8, 9], "I am foo2", 34, 6.28) + foobucket = FooBucket("bucket1", [foo1, foo2]) self.foofile = FooFile(buckets=[foobucket]) def tearDown(self): @@ -1544,7 +1743,7 @@ def tearDown(self): def test_set_written_on_write(self): """Test that write_builder changes the written flag of the builder and its children from False to True.""" - with HDF5IO(self.path, manager=self.manager, mode='w') as io: + with HDF5IO(self.path, manager=self.manager, mode="w") as io: builder = self.manager.build(container=self.foofile, source=self.path) self.assertFalse(io.get_written(builder)) self._check_written_children(io, builder, False) @@ -1564,15 +1763,16 @@ def _check_written_children(self, io, builder, val): class H5DataIOValid(TestCase): - def setUp(self): - self.paths = [get_temp_filepath(), ] + self.paths = [ + get_temp_filepath(), + ] - self.foo1 = Foo('foo1', H5DataIO([1, 2, 3, 4, 5]), "I am foo1", 17, 3.14) - bucket1 = FooBucket('bucket1', [self.foo1]) + self.foo1 = Foo("foo1", H5DataIO([1, 2, 3, 4, 5]), "I am foo1", 17, 3.14) + bucket1 = FooBucket("bucket1", [self.foo1]) foofile1 = FooFile(buckets=[bucket1]) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='w') as io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="w") as io: io.write(foofile1) def tearDown(self): @@ -1585,25 +1785,30 @@ def test_valid(self): def test_read_valid(self): """Test that h5py.H5Dataset.id.valid works as expected""" - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='r') as io: + with HDF5IO(self.paths[0], 
manager=get_foo_buildmanager(), mode="r") as io: read_foofile1 = io.read() - self.assertTrue(read_foofile1.buckets['bucket1'].foos['foo1'].my_data.id.valid) + self.assertTrue(read_foofile1.buckets["bucket1"].foos["foo1"].my_data.id.valid) - self.assertFalse(read_foofile1.buckets['bucket1'].foos['foo1'].my_data.id.valid) + self.assertFalse(read_foofile1.buckets["bucket1"].foos["foo1"].my_data.id.valid) def test_link(self): - """Test that wrapping of linked data within H5DataIO """ - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='r') as io: + """Test that wrapping of linked data within H5DataIO""" + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="r") as io: read_foofile1 = io.read() - self.foo2 = Foo('foo2', H5DataIO(data=read_foofile1.buckets['bucket1'].foos['foo1'].my_data), - "I am foo2", 17, 3.14) - bucket2 = FooBucket('bucket2', [self.foo2]) + self.foo2 = Foo( + "foo2", + H5DataIO(data=read_foofile1.buckets["bucket1"].foos["foo1"].my_data), + "I am foo2", + 17, + 3.14, + ) + bucket2 = FooBucket("bucket2", [self.foo2]) foofile2 = FooFile(buckets=[bucket2]) self.paths.append(get_temp_filepath()) - with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode='w') as io: + with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode="w") as io: io.write(foofile2) self.assertTrue(self.foo2.my_data.valid) # test valid @@ -1613,7 +1818,7 @@ def test_link(self): # test loop through iterable match = [1, 2, 3, 4, 5] - for (i, j) in zip(self.foo2.my_data, match): + for i, j in zip(self.foo2.my_data, match): self.assertEqual(i, j) # test iterator @@ -1626,10 +1831,16 @@ def test_link(self): with self.assertRaisesWith(InvalidDataIOError, "Cannot get length of data. Data is not valid."): len(self.foo2.my_data) - with self.assertRaisesWith(InvalidDataIOError, "Cannot get attribute 'shape' of data. Data is not valid."): + with self.assertRaisesWith( + InvalidDataIOError, + "Cannot get attribute 'shape' of data. Data is not valid.", + ): self.foo2.my_data.shape - with self.assertRaisesWith(InvalidDataIOError, "Cannot convert data to array. Data is not valid."): + with self.assertRaisesWith( + InvalidDataIOError, + "Cannot convert data to array. Data is not valid.", + ): np.array(self.foo2.my_data) with self.assertRaisesWith(InvalidDataIOError, "Cannot iterate on data. 
Data is not valid."): @@ -1640,9 +1851,9 @@ def test_link(self): iter(self.foo2.my_data) # re-open the file with the data linking to other file (still closed) - with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode='r') as io: + with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode="r") as io: read_foofile2 = io.read() - read_foo2 = read_foofile2.buckets['bucket2'].foos['foo2'] + read_foo2 = read_foofile2.buckets["bucket2"].foos["foo2"] # note that read_foo2 dataset does not have an attribute 'valid' self.assertEqual(len(read_foo2.my_data), 5) # test len @@ -1651,7 +1862,7 @@ def test_link(self): # test loop through iterable match = [1, 2, 3, 4, 5] - for (i, j) in zip(read_foo2.my_data, match): + for i, j in zip(read_foo2.my_data, match): self.assertEqual(i, j) # test iterator @@ -1660,27 +1871,26 @@ def test_link(self): class TestReadLink(TestCase): - def setUp(self): self.target_path = get_temp_filepath() self.link_path = get_temp_filepath() - root1 = GroupBuilder(name='root') - subgroup = GroupBuilder(name='test_group') + root1 = GroupBuilder(name="root") + subgroup = GroupBuilder(name="test_group") root1.set_group(subgroup) - dataset = DatasetBuilder('test_dataset', data=[1, 2, 3, 4]) + dataset = DatasetBuilder("test_dataset", data=[1, 2, 3, 4]) subgroup.set_dataset(dataset) - root2 = GroupBuilder(name='root') - link_group = LinkBuilder(subgroup, 'link_to_test_group') + root2 = GroupBuilder(name="root") + link_group = LinkBuilder(subgroup, "link_to_test_group") root2.set_link(link_group) - link_dataset = LinkBuilder(dataset, 'link_to_test_dataset') + link_dataset = LinkBuilder(dataset, "link_to_test_dataset") root2.set_link(link_dataset) - with HDF5IO(self.target_path, manager=get_foo_buildmanager(), mode='w') as io: + with HDF5IO(self.target_path, manager=get_foo_buildmanager(), mode="w") as io: io.write_builder(root1) root1.source = self.target_path - with HDF5IO(self.link_path, manager=get_foo_buildmanager(), mode='w') as io: + with HDF5IO(self.link_path, manager=get_foo_buildmanager(), mode="w") as io: io.write_builder(root2) root2.source = self.link_path self.ios = [] @@ -1695,11 +1905,11 @@ def test_set_link_loc(self): """ Test that Builder location is set when it is read as a link """ - read_io = HDF5IO(self.link_path, manager=get_foo_buildmanager(), mode='r') + read_io = HDF5IO(self.link_path, manager=get_foo_buildmanager(), mode="r") self.ios.append(read_io) # store IO object for closing in tearDown bldr = read_io.read_builder() - self.assertEqual(bldr['link_to_test_group'].builder.location, '/') - self.assertEqual(bldr['link_to_test_dataset'].builder.location, '/test_group') + self.assertEqual(bldr["link_to_test_group"].builder.location, "/") + self.assertEqual(bldr["link_to_test_dataset"].builder.location, "/test_group") read_io.close() def test_link_to_link(self): @@ -1707,29 +1917,32 @@ def test_link_to_link(self): Test that link to link gets written and read properly """ link_to_link_path = get_temp_filepath() - read_io1 = HDF5IO(self.link_path, manager=get_foo_buildmanager(), mode='r') + read_io1 = HDF5IO(self.link_path, manager=get_foo_buildmanager(), mode="r") self.ios.append(read_io1) # store IO object for closing in tearDown bldr1 = read_io1.read_builder() - root3 = GroupBuilder(name='root') + root3 = GroupBuilder(name="root") - link = LinkBuilder(bldr1['link_to_test_group'].builder, 'link_to_link') + link = LinkBuilder(bldr1["link_to_test_group"].builder, "link_to_link") root3.set_link(link) - 
root3.set_link(LinkBuilder(bldr1['link_to_test_group'].builder, 'link_to_link')) - with HDF5IO(link_to_link_path, manager=get_foo_buildmanager(), mode='w') as io: + root3.set_link(LinkBuilder(bldr1["link_to_test_group"].builder, "link_to_link")) + with HDF5IO(link_to_link_path, manager=get_foo_buildmanager(), mode="w") as io: io.write_builder(root3) read_io1.close() - read_io2 = HDF5IO(link_to_link_path, manager=get_foo_buildmanager(), mode='r') + read_io2 = HDF5IO(link_to_link_path, manager=get_foo_buildmanager(), mode="r") self.ios.append(read_io2) bldr2 = read_io2.read_builder() - self.assertEqual(bldr2['link_to_link'].builder.source, self.target_path) + self.assertEqual(bldr2["link_to_link"].builder.source, self.target_path) read_io2.close() def test_broken_link(self): """Test that opening a file with a broken link raises a warning but is still readable.""" os.remove(self.target_path) - with self.assertWarnsWith(BrokenLinkWarning, 'Path to Group altered/broken at /link_to_test_group'): - with HDF5IO(self.link_path, manager=get_foo_buildmanager(), mode='r') as read_io: + with self.assertWarnsWith( + BrokenLinkWarning, + "Path to Group altered/broken at /link_to_test_group", + ): + with HDF5IO(self.link_path, manager=get_foo_buildmanager(), mode="r") as read_io: bldr = read_io.read_builder() self.assertDictEqual(bldr.links, {}) @@ -1737,30 +1950,32 @@ def test_broken_linked_data(self): """Test that opening a file with a broken link raises a warning but is still readable.""" manager = get_foo_buildmanager() - with HDF5IO(self.target_path, manager=manager, mode='r') as read_io: + with HDF5IO(self.target_path, manager=manager, mode="r") as read_io: read_root = read_io.read_builder() - read_dataset_data = read_root.groups['test_group'].datasets['test_dataset'].data + read_dataset_data = read_root.groups["test_group"].datasets["test_dataset"].data - with HDF5IO(self.link_path, manager=manager, mode='w') as write_io: - root2 = GroupBuilder(name='root') - dataset = DatasetBuilder(name='link_to_test_dataset', data=read_dataset_data) + with HDF5IO(self.link_path, manager=manager, mode="w") as write_io: + root2 = GroupBuilder(name="root") + dataset = DatasetBuilder(name="link_to_test_dataset", data=read_dataset_data) root2.set_dataset(dataset) write_io.write_builder(root2, link_data=True) os.remove(self.target_path) - with self.assertWarnsWith(BrokenLinkWarning, 'Path to Group altered/broken at /link_to_test_dataset'): - with HDF5IO(self.link_path, manager=get_foo_buildmanager(), mode='r') as read_io: + with self.assertWarnsWith( + BrokenLinkWarning, + "Path to Group altered/broken at /link_to_test_dataset", + ): + with HDF5IO(self.link_path, manager=get_foo_buildmanager(), mode="r") as read_io: bldr = read_io.read_builder() self.assertDictEqual(bldr.links, {}) class TestBuildWriteLinkToLink(TestCase): - def setUp(self): self.paths = [ get_temp_filepath(), get_temp_filepath(), - get_temp_filepath() + get_temp_filepath(), ] self.ios = [] @@ -1771,32 +1986,32 @@ def tearDown(self): def test_external_link_to_external_link(self): """Test writing a file with external links to external links.""" - foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) - foobucket = FooBucket('bucket1', [foo1]) + foo1 = Foo("foo1", [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) + foobucket = FooBucket("bucket1", [foo1]) foofile = FooFile(buckets=[foobucket]) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='w') as write_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="w") as write_io: 
+ with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="w") as write_io:
write_io.write(foofile) manager = get_foo_buildmanager() - with HDF5IO(self.paths[0], manager=manager, mode='r') as read_io: + with HDF5IO(self.paths[0], manager=manager, mode="r") as read_io: read_foofile = read_io.read() # make external link to existing group - foofile2 = FooFile(foo_link=read_foofile.buckets['bucket1'].foos['foo1']) + foofile2 = FooFile(foo_link=read_foofile.buckets["bucket1"].foos["foo1"]) - with HDF5IO(self.paths[1], manager=manager, mode='w') as write_io: + with HDF5IO(self.paths[1], manager=manager, mode="w") as write_io: write_io.write(foofile2) manager = get_foo_buildmanager() - with HDF5IO(self.paths[1], manager=manager, mode='r') as read_io: + with HDF5IO(self.paths[1], manager=manager, mode="r") as read_io: self.ios.append(read_io) # track IO objects for tearDown read_foofile2 = read_io.read() foofile3 = FooFile(foo_link=read_foofile2.foo_link) # make external link to external link - with HDF5IO(self.paths[2], manager=manager, mode='w') as write_io: + with HDF5IO(self.paths[2], manager=manager, mode="w") as write_io: write_io.write(foofile3) - with HDF5IO(self.paths[2], manager=get_foo_buildmanager(), mode='r') as read_io: + with HDF5IO(self.paths[2], manager=get_foo_buildmanager(), mode="r") as read_io: self.ios.append(read_io) # track IO objects for tearDown read_foofile3 = read_io.read() @@ -1804,31 +2019,31 @@ def test_external_link_to_external_link(self): def test_external_link_to_soft_link(self): """Test writing a file with external links to external links.""" - foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) - foobucket = FooBucket('bucket1', [foo1]) + foo1 = Foo("foo1", [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) + foobucket = FooBucket("bucket1", [foo1]) foofile = FooFile(buckets=[foobucket], foo_link=foo1) # create soft link - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='w') as write_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="w") as write_io: write_io.write(foofile) manager = get_foo_buildmanager() - with HDF5IO(self.paths[0], manager=manager, mode='r') as read_io: + with HDF5IO(self.paths[0], manager=manager, mode="r") as read_io: read_foofile = read_io.read() foofile2 = FooFile(foo_link=read_foofile.foo_link) # make external link to existing soft link - with HDF5IO(self.paths[1], manager=manager, mode='w') as write_io: + with HDF5IO(self.paths[1], manager=manager, mode="w") as write_io: write_io.write(foofile2) manager = get_foo_buildmanager() - with HDF5IO(self.paths[1], manager=manager, mode='r') as read_io: + with HDF5IO(self.paths[1], manager=manager, mode="r") as read_io: self.ios.append(read_io) # track IO objects for tearDown read_foofile2 = read_io.read() foofile3 = FooFile(foo_link=read_foofile2.foo_link) # make external link to external link - with HDF5IO(self.paths[2], manager=manager, mode='w') as write_io: + with HDF5IO(self.paths[2], manager=manager, mode="w") as write_io: write_io.write(foofile3) - with HDF5IO(self.paths[2], manager=get_foo_buildmanager(), mode='r') as read_io: + with HDF5IO(self.paths[2], manager=get_foo_buildmanager(), mode="r") as read_io: self.ios.append(read_io) # track IO objects for tearDown read_foofile3 = read_io.read() @@ -1836,17 +2051,16 @@ def test_external_link_to_soft_link(self): class TestLinkData(TestCase): - def setUp(self): self.target_path = get_temp_filepath() self.link_path = get_temp_filepath() - root1 = GroupBuilder(name='root') - subgroup = GroupBuilder(name='test_group') + root1 = GroupBuilder(name="root") + subgroup = 
GroupBuilder(name="test_group") root1.set_group(subgroup) - dataset = DatasetBuilder('test_dataset', data=[1, 2, 3, 4]) + dataset = DatasetBuilder("test_dataset", data=[1, 2, 3, 4]) subgroup.set_dataset(dataset) - with HDF5IO(self.target_path, manager=get_foo_buildmanager(), mode='w') as io: + with HDF5IO(self.target_path, manager=get_foo_buildmanager(), mode="w") as io: io.write_builder(root1) def tearDown(self): @@ -1858,44 +2072,43 @@ def tearDown(self): def test_link_data_true(self): """Test that the argument link_data=True for write_builder creates an external link.""" manager = get_foo_buildmanager() - with HDF5IO(self.target_path, manager=manager, mode='r') as read_io: + with HDF5IO(self.target_path, manager=manager, mode="r") as read_io: read_root = read_io.read_builder() - read_dataset_data = read_root.groups['test_group'].datasets['test_dataset'].data + read_dataset_data = read_root.groups["test_group"].datasets["test_dataset"].data - with HDF5IO(self.link_path, manager=manager, mode='w') as write_io: - root2 = GroupBuilder(name='root') - dataset = DatasetBuilder(name='link_to_test_dataset', data=read_dataset_data) + with HDF5IO(self.link_path, manager=manager, mode="w") as write_io: + root2 = GroupBuilder(name="root") + dataset = DatasetBuilder(name="link_to_test_dataset", data=read_dataset_data) root2.set_dataset(dataset) write_io.write_builder(root2, link_data=True) - with File(self.link_path, mode='r') as f: - self.assertIsInstance(f.get('link_to_test_dataset', getlink=True), ExternalLink) + with File(self.link_path, mode="r") as f: + self.assertIsInstance(f.get("link_to_test_dataset", getlink=True), ExternalLink) def test_link_data_false(self): """Test that the argument link_data=False for write_builder copies the data.""" manager = get_foo_buildmanager() - with HDF5IO(self.target_path, manager=manager, mode='r') as read_io: + with HDF5IO(self.target_path, manager=manager, mode="r") as read_io: read_root = read_io.read_builder() - read_dataset_data = read_root.groups['test_group'].datasets['test_dataset'].data + read_dataset_data = read_root.groups["test_group"].datasets["test_dataset"].data - with HDF5IO(self.link_path, manager=manager, mode='w') as write_io: - root2 = GroupBuilder(name='root') - dataset = DatasetBuilder(name='link_to_test_dataset', data=read_dataset_data) + with HDF5IO(self.link_path, manager=manager, mode="w") as write_io: + root2 = GroupBuilder(name="root") + dataset = DatasetBuilder(name="link_to_test_dataset", data=read_dataset_data) root2.set_dataset(dataset) write_io.write_builder(root2, link_data=False) - with File(self.link_path, mode='r') as f: - self.assertFalse(isinstance(f.get('link_to_test_dataset', getlink=True), ExternalLink)) - self.assertListEqual(f.get('link_to_test_dataset')[:].tolist(), [1, 2, 3, 4]) + with File(self.link_path, mode="r") as f: + self.assertFalse(isinstance(f.get("link_to_test_dataset", getlink=True), ExternalLink)) + self.assertListEqual(f.get("link_to_test_dataset")[:].tolist(), [1, 2, 3, 4]) class TestLoadNamespaces(TestCase): - def setUp(self): self.manager = get_foo_buildmanager() self.path = get_temp_filepath() container = FooFile() - with HDF5IO(self.path, manager=self.manager, mode='w') as io: + with HDF5IO(self.path, manager=self.manager, mode="w") as io: io.write(container) def tearDown(self): @@ -1906,14 +2119,20 @@ def test_load_namespaces_none_version(self): """Test that reading a file with a cached namespace and None version works but raises a warning.""" # make the file have group name "None" instead of 
"0.1.0" (namespace version is used as group name) # and set the version key to "None" - with h5py.File(self.path, mode='r+') as f: + with h5py.File(self.path, mode="r+") as f: # rename the group - f.move('/specifications/test_core/0.1.0', '/specifications/test_core/None') + f.move( + "/specifications/test_core/0.1.0", + "/specifications/test_core/None", + ) # replace the namespace dataset with a serialized dict with the version key set to 'None' - new_ns = ('{"namespaces":[{"doc":"a test namespace","schema":[{"source":"test"}],"name":"test_core",' - '"version":"None"}]}') - f['/specifications/test_core/None/namespace'][()] = new_ns + new_ns = ( + '{"namespaces":[{"doc":"a test' + ' namespace","schema":[{"source":"test"}],"name":"test_core",' + '"version":"None"}]}' + ) + f["/specifications/test_core/None/namespace"][()] = new_ns # load the namespace from file ns_catalog = NamespaceCatalog() @@ -1925,18 +2144,23 @@ def test_load_namespaces_unversioned(self): """Test that reading a file with a cached, unversioned version works but raises a warning.""" # make the file have group name "unversioned" instead of "0.1.0" (namespace version is used as group name) # and remove the version key - with h5py.File(self.path, mode='r+') as f: + with h5py.File(self.path, mode="r+") as f: # rename the group - f.move('/specifications/test_core/0.1.0', '/specifications/test_core/unversioned') + f.move( + "/specifications/test_core/0.1.0", + "/specifications/test_core/unversioned", + ) # replace the namespace dataset with a serialized dict without the version key - new_ns = ('{"namespaces":[{"doc":"a test namespace","schema":[{"source":"test"}],"name":"test_core"}]}') - f['/specifications/test_core/unversioned/namespace'][()] = new_ns + new_ns = '{"namespaces":[{"doc":"a test namespace","schema":[{"source":"test"}],"name":"test_core"}]}' + f["/specifications/test_core/unversioned/namespace"][()] = new_ns # load the namespace from file ns_catalog = NamespaceCatalog() - msg = ("Loaded namespace '%s' is missing the required key 'version'. Version will be set to " - "'%s'. Please notify the extension author." % (CORE_NAMESPACE, SpecNamespace.UNVERSIONED)) + msg = ( + "Loaded namespace '%s' is missing the required key 'version'. Version will" + " be set to '%s'. Please notify the extension author." % (CORE_NAMESPACE, SpecNamespace.UNVERSIONED) + ) with self.assertWarnsWith(UserWarning, msg): HDF5IO.load_namespaces(ns_catalog, self.path) @@ -1944,7 +2168,7 @@ def test_load_namespaces_path(self): """Test that loading namespaces given a path is OK and returns the correct dictionary.""" ns_catalog = NamespaceCatalog() d = HDF5IO.load_namespaces(ns_catalog, self.path) - self.assertEqual(d, {'test_core': {}}) # test_core has no dependencies + self.assertEqual(d, {"test_core": {}}) # test_core has no dependencies def test_load_namespaces_no_path_no_file(self): """Test that loading namespaces without a path or file raises an error.""" @@ -1958,76 +2182,83 @@ def test_load_namespaces_file_no_path(self): """ Test that loading namespaces from an h5py.File not backed by a file on disk is OK and does not close the file. 
""" - with open(self.path, 'rb') as raw_file: + with open(self.path, "rb") as raw_file: buffer = BytesIO(raw_file.read()) - file_obj = h5py.File(buffer, 'r') + file_obj = h5py.File(buffer, "r") ns_catalog = NamespaceCatalog() d = HDF5IO.load_namespaces(ns_catalog, file=file_obj) self.assertTrue(file_obj.__bool__()) # check file object is still open - self.assertEqual(d, {'test_core': {}}) + self.assertEqual(d, {"test_core": {}}) file_obj.close() def test_load_namespaces_file_path_matched(self): """Test that loading namespaces given an h5py.File and path is OK and does not close the file.""" - with h5py.File(self.path, 'r') as file_obj: + with h5py.File(self.path, "r") as file_obj: ns_catalog = NamespaceCatalog() d = HDF5IO.load_namespaces(ns_catalog, path=self.path, file=file_obj) self.assertTrue(file_obj.__bool__()) # check file object is still open - self.assertEqual(d, {'test_core': {}}) + self.assertEqual(d, {"test_core": {}}) def test_load_namespaces_file_path_mismatched(self): """Test that loading namespaces given an h5py.File and path that are mismatched raises an error.""" - with h5py.File(self.path, 'r') as file_obj: + with h5py.File(self.path, "r") as file_obj: ns_catalog = NamespaceCatalog() msg = "You argued 'different_path' as this object's path, but supplied a file with filename: %s" % self.path with self.assertRaisesWith(ValueError, msg): - HDF5IO.load_namespaces(ns_catalog, path='different_path', file=file_obj) + HDF5IO.load_namespaces(ns_catalog, path="different_path", file=file_obj) def test_load_namespaces_with_pathlib_path(self): """Test that loading a namespace using a valid pathlib Path is OK and returns the correct dictionary.""" pathlib_path = Path(self.path) ns_catalog = NamespaceCatalog() d = HDF5IO.load_namespaces(ns_catalog, pathlib_path) - self.assertEqual(d, {'test_core': {}}) # test_core has no dependencies + self.assertEqual(d, {"test_core": {}}) # test_core has no dependencies def test_load_namespaces_with_dependencies(self): """Test loading namespaces where one includes another.""" + class MyFoo(Container): pass - myfoo_spec = GroupSpec(doc="A MyFoo", data_type_def='MyFoo', data_type_inc='Foo') + myfoo_spec = GroupSpec(doc="A MyFoo", data_type_def="MyFoo", data_type_inc="Foo") spec_catalog = SpecCatalog() - name = 'test_core2' + name = "test_core2" namespace = SpecNamespace( - doc='a test namespace', + doc="a test namespace", name=name, - schema=[{'source': 'test2.yaml', 'namespace': 'test_core'}], # depends on test_core - version='0.1.0', - catalog=spec_catalog + schema=[{"source": "test2.yaml", "namespace": "test_core"}], # depends on test_core + version="0.1.0", + catalog=spec_catalog, ) - spec_catalog.register_spec(myfoo_spec, 'test2.yaml') + spec_catalog.register_spec(myfoo_spec, "test2.yaml") namespace_catalog = NamespaceCatalog() namespace_catalog.add_namespace(name, namespace) type_map = TypeMap(namespace_catalog) - type_map.register_container_type(name, 'MyFoo', MyFoo) + type_map.register_container_type(name, "MyFoo", MyFoo) type_map.merge(self.manager.type_map, ns_catalog=True) manager = BuildManager(type_map) - container = MyFoo(name='myfoo') - with HDF5IO(self.path, manager=manager, mode='a') as io: # append to file + container = MyFoo(name="myfoo") + with HDF5IO(self.path, manager=manager, mode="a") as io: # append to file io.write(container) ns_catalog = NamespaceCatalog() d = HDF5IO.load_namespaces(ns_catalog, self.path) - self.assertEqual(d, {'test_core': {}, 'test_core2': {'test_core': ('Foo', 'FooBucket', 'FooFile')}}) + self.assertEqual( 
+ d, + { + "test_core": {}, + "test_core2": {"test_core": ("Foo", "FooBucket", "FooFile")}, + }, + ) def test_load_namespaces_no_specloc(self): """Test loading namespaces where the file does not contain a SPEC_LOC_ATTR.""" # delete the spec location attribute from the file - with h5py.File(self.path, mode='r+') as f: + with h5py.File(self.path, mode="r+") as f: del f.attrs[SPEC_LOC_ATTR] # load the namespace from file @@ -2040,74 +2271,80 @@ def test_load_namespaces_no_specloc(self): def test_load_namespaces_resolve_custom_deps(self): """Test that reading a file with a cached namespace and different def/inc keys works.""" # Setup all the data we need - foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) - foobucket = FooBucket('bucket1', [foo1]) + foo1 = Foo("foo1", [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) + foobucket = FooBucket("bucket1", [foo1]) foofile = FooFile(buckets=[foobucket]) - with HDF5IO(self.path, manager=self.manager, mode='w') as io: + with HDF5IO(self.path, manager=self.manager, mode="w") as io: io.write(foofile) - with h5py.File(self.path, mode='r+') as f: + with h5py.File(self.path, mode="r+") as f: # add two types where one extends the other and overrides an attribute # check that the inherited attribute resolves correctly despite having a different def/inc key than those # used in the namespace catalog - added_types = (',{"data_type_def":"BigFoo","data_type_inc":"Foo","doc":"doc","attributes":[' - '{"name":"my_attr","dtype":"text","doc":"an attr"}]},' - '{"data_type_def":"BiggerFoo","data_type_inc":"BigFoo","doc":"doc"}]}') - old_test_source = f['/specifications/test_core/0.1.0/test'] + added_types = ( + ',{"data_type_def":"BigFoo","data_type_inc":"Foo","doc":"doc","attributes":[' + '{"name":"my_attr","dtype":"text","doc":"an attr"}]},' + '{"data_type_def":"BiggerFoo","data_type_inc":"BigFoo","doc":"doc"}]}' + ) + old_test_source = f["/specifications/test_core/0.1.0/test"] # strip the ]} from end, then add to groups if H5PY_3: # string datasets are returned as bytes - old_test_source[()] = old_test_source[()][0:-2].decode('utf-8') + added_types + old_test_source[()] = old_test_source[()][0:-2].decode("utf-8") + added_types else: old_test_source[()] = old_test_source[()][0:-2] + added_types - new_ns = ('{"namespaces":[{"doc":"a test namespace","schema":[' - '{"namespace":"test_core","my_data_types":["Foo"]},' - '{"source":"test-ext.extensions"}' - '],"name":"test-ext","version":"0.1.0"}]}') - f.create_dataset('/specifications/test-ext/0.1.0/namespace', data=new_ns) + new_ns = ( + '{"namespaces":[{"doc":"a test namespace","schema":[' + '{"namespace":"test_core","my_data_types":["Foo"]},' + '{"source":"test-ext.extensions"}' + '],"name":"test-ext","version":"0.1.0"}]}' + ) + f.create_dataset("/specifications/test-ext/0.1.0/namespace", data=new_ns) new_ext = '{"groups":[{"my_data_type_def":"FooExt","my_data_type_inc":"Foo","doc":"doc"}]}' - f.create_dataset('/specifications/test-ext/0.1.0/test-ext.extensions', data=new_ext) + f.create_dataset( + "/specifications/test-ext/0.1.0/test-ext.extensions", + data=new_ext, + ) # load the namespace from file ns_catalog = NamespaceCatalog(CustomGroupSpec, CustomDatasetSpec, CustomSpecNamespace) namespace_deps = HDF5IO.load_namespaces(ns_catalog, self.path) # test that the dependencies are correct - expected = ('Foo',) - self.assertTupleEqual((namespace_deps['test-ext']['test_core']), expected) + expected = ("Foo",) + self.assertTupleEqual((namespace_deps["test-ext"]["test_core"]), expected) # test that the types are loaded - 
types = ns_catalog.get_types('test-ext.extensions') - expected = ('FooExt',) + types = ns_catalog.get_types("test-ext.extensions") + expected = ("FooExt",) self.assertTupleEqual(types, expected) # test that the def_key is updated for test-ext ns - foo_ext_spec = ns_catalog.get_spec('test-ext', 'FooExt') - self.assertTrue('my_data_type_def' in foo_ext_spec) - self.assertTrue('my_data_type_inc' in foo_ext_spec) + foo_ext_spec = ns_catalog.get_spec("test-ext", "FooExt") + self.assertTrue("my_data_type_def" in foo_ext_spec) + self.assertTrue("my_data_type_inc" in foo_ext_spec) # test that the data_type_def is replaced with my_data_type_def for test_core ns - bigger_foo_spec = ns_catalog.get_spec('test_core', 'BiggerFoo') - self.assertTrue('my_data_type_def' in bigger_foo_spec) - self.assertTrue('my_data_type_inc' in bigger_foo_spec) + bigger_foo_spec = ns_catalog.get_spec("test_core", "BiggerFoo") + self.assertTrue("my_data_type_def" in bigger_foo_spec) + self.assertTrue("my_data_type_inc" in bigger_foo_spec) # test that my_attr is properly inherited in BiggerFoo from BigFoo and attr1, attr3 are inherited from Foo self.assertTrue(len(bigger_foo_spec.attributes) == 3) class TestGetNamespaces(TestCase): - def create_test_namespace(self, name, version): - file_spec = GroupSpec(doc="A FooFile", data_type_def='FooFile') + file_spec = GroupSpec(doc="A FooFile", data_type_def="FooFile") spec_catalog = SpecCatalog() namespace = SpecNamespace( - doc='a test namespace', + doc="a test namespace", name=name, - schema=[{'source': 'test.yaml'}], + schema=[{"source": "test.yaml"}], version=version, - catalog=spec_catalog + catalog=spec_catalog, ) - spec_catalog.register_spec(file_spec, 'test.yaml') + spec_catalog.register_spec(file_spec, "test.yaml") return namespace def write_test_file(self, name, version, mode): @@ -2115,7 +2352,7 @@ def write_test_file(self, name, version, mode): namespace_catalog = NamespaceCatalog() namespace_catalog.add_namespace(name, namespace) type_map = TypeMap(namespace_catalog) - type_map.register_container_type(name, 'FooFile', FooFile) + type_map.register_container_type(name, "FooFile", FooFile) manager = BuildManager(type_map) with HDF5IO(self.path, manager=manager, mode=mode) as io: io.write(self.container) @@ -2132,119 +2369,137 @@ def tearDown(self): def test_get_namespaces_with_path(self): """Test getting namespaces given a path.""" - self.write_test_file('test_core', '0.1.0', 'w') + self.write_test_file("test_core", "0.1.0", "w") ret = HDF5IO.get_namespaces(path=self.path) - self.assertEqual(ret, {'test_core': '0.1.0'}) + self.assertEqual(ret, {"test_core": "0.1.0"}) def test_get_namespaces_with_file(self): """Test getting namespaces given a file object.""" - self.write_test_file('test_core', '0.1.0', 'w') + self.write_test_file("test_core", "0.1.0", "w") - with File(self.path, 'r') as f: + with File(self.path, "r") as f: ret = HDF5IO.get_namespaces(file=f) - self.assertEqual(ret, {'test_core': '0.1.0'}) + self.assertEqual(ret, {"test_core": "0.1.0"}) self.assertTrue(f.__bool__()) # check file object is still open def test_get_namespaces_different_versions(self): """Test getting namespaces with multiple versions given a path.""" # write file with spec with smaller version string - self.write_test_file('test_core', '0.0.10', 'w') + self.write_test_file("test_core", "0.0.10", "w") # append to file with spec with larger version string - self.write_test_file('test_core', '0.1.0', 'a') + self.write_test_file("test_core", "0.1.0", "a") ret = 
HDF5IO.get_namespaces(path=self.path) - self.assertEqual(ret, {'test_core': '0.1.0'}) + self.assertEqual(ret, {"test_core": "0.1.0"}) def test_get_namespaces_multiple_namespaces(self): """Test getting multiple namespaces given a path.""" - self.write_test_file('test_core1', '0.0.10', 'w') - self.write_test_file('test_core2', '0.1.0', 'a') + self.write_test_file("test_core1", "0.0.10", "w") + self.write_test_file("test_core2", "0.1.0", "a") ret = HDF5IO.get_namespaces(path=self.path) - self.assertEqual(ret, {'test_core1': '0.0.10', 'test_core2': '0.1.0'}) + self.assertEqual(ret, {"test_core1": "0.0.10", "test_core2": "0.1.0"}) def test_get_namespaces_none_version(self): """Test getting namespaces where file has one None-versioned namespace.""" - self.write_test_file('test_core', '0.1.0', 'w') + self.write_test_file("test_core", "0.1.0", "w") # make the file have group name "None" instead of "0.1.0" (namespace version is used as group name) # and set the version key to "None" - with h5py.File(self.path, mode='r+') as f: + with h5py.File(self.path, mode="r+") as f: # rename the group - f.move('/specifications/test_core/0.1.0', '/specifications/test_core/None') + f.move( + "/specifications/test_core/0.1.0", + "/specifications/test_core/None", + ) # replace the namespace dataset with a serialized dict with the version key set to 'None' - new_ns = ('{"namespaces":[{"doc":"a test namespace","schema":[{"source":"test"}],"name":"test_core",' - '"version":"None"}]}') - f['/specifications/test_core/None/namespace'][()] = new_ns + new_ns = ( + '{"namespaces":[{"doc":"a test' + ' namespace","schema":[{"source":"test"}],"name":"test_core",' + '"version":"None"}]}' + ) + f["/specifications/test_core/None/namespace"][()] = new_ns ret = HDF5IO.get_namespaces(path=self.path) - self.assertEqual(ret, {'test_core': 'None'}) + self.assertEqual(ret, {"test_core": "None"}) def test_get_namespaces_none_and_other_version(self): """Test getting namespaces file has a namespace with a normal version and an 'None" version.""" - self.write_test_file('test_core', '0.1.0', 'w') + self.write_test_file("test_core", "0.1.0", "w") # make the file have group name "None" instead of "0.1.0" (namespace version is used as group name) # and set the version key to "None" - with h5py.File(self.path, mode='r+') as f: + with h5py.File(self.path, mode="r+") as f: # rename the group - f.move('/specifications/test_core/0.1.0', '/specifications/test_core/None') + f.move( + "/specifications/test_core/0.1.0", + "/specifications/test_core/None", + ) # replace the namespace dataset with a serialized dict with the version key set to 'None' - new_ns = ('{"namespaces":[{"doc":"a test namespace","schema":[{"source":"test"}],"name":"test_core",' - '"version":"None"}]}') - f['/specifications/test_core/None/namespace'][()] = new_ns + new_ns = ( + '{"namespaces":[{"doc":"a test' + ' namespace","schema":[{"source":"test"}],"name":"test_core",' + '"version":"None"}]}' + ) + f["/specifications/test_core/None/namespace"][()] = new_ns # append to file with spec with a larger version string - self.write_test_file('test_core', '0.2.0', 'a') + self.write_test_file("test_core", "0.2.0", "a") ret = HDF5IO.get_namespaces(path=self.path) - self.assertEqual(ret, {'test_core': '0.2.0'}) + self.assertEqual(ret, {"test_core": "0.2.0"}) def test_get_namespaces_unversioned(self): """Test getting namespaces where file has one unversioned namespace.""" - self.write_test_file('test_core', '0.1.0', 'w') + self.write_test_file("test_core", "0.1.0", "w") # make the file have 
group name "unversioned" instead of "0.1.0" (namespace version is used as group name) - with h5py.File(self.path, mode='r+') as f: + with h5py.File(self.path, mode="r+") as f: # rename the group - f.move('/specifications/test_core/0.1.0', '/specifications/test_core/unversioned') + f.move( + "/specifications/test_core/0.1.0", + "/specifications/test_core/unversioned", + ) # replace the namespace dataset with a serialized dict without the version key - new_ns = ('{"namespaces":[{"doc":"a test namespace","schema":[{"source":"test"}],"name":"test_core"}]}') - f['/specifications/test_core/unversioned/namespace'][()] = new_ns + new_ns = '{"namespaces":[{"doc":"a test namespace","schema":[{"source":"test"}],"name":"test_core"}]}' + f["/specifications/test_core/unversioned/namespace"][()] = new_ns ret = HDF5IO.get_namespaces(path=self.path) - self.assertEqual(ret, {'test_core': 'unversioned'}) + self.assertEqual(ret, {"test_core": "unversioned"}) def test_get_namespaces_unversioned_and_other(self): """Test getting namespaces file has a namespace with a normal version and an 'unversioned" version.""" - self.write_test_file('test_core', '0.1.0', 'w') + self.write_test_file("test_core", "0.1.0", "w") # make the file have group name "unversioned" instead of "0.1.0" (namespace version is used as group name) - with h5py.File(self.path, mode='r+') as f: + with h5py.File(self.path, mode="r+") as f: # rename the group - f.move('/specifications/test_core/0.1.0', '/specifications/test_core/unversioned') + f.move( + "/specifications/test_core/0.1.0", + "/specifications/test_core/unversioned", + ) # replace the namespace dataset with a serialized dict without the version key - new_ns = ('{"namespaces":[{"doc":"a test namespace","schema":[{"source":"test"}],"name":"test_core"}]}') - f['/specifications/test_core/unversioned/namespace'][()] = new_ns + new_ns = '{"namespaces":[{"doc":"a test namespace","schema":[{"source":"test"}],"name":"test_core"}]}' + f["/specifications/test_core/unversioned/namespace"][()] = new_ns # append to file with spec with a larger version string - self.write_test_file('test_core', '0.2.0', 'a') + self.write_test_file("test_core", "0.2.0", "a") ret = HDF5IO.get_namespaces(path=self.path) - self.assertEqual(ret, {'test_core': '0.2.0'}) + self.assertEqual(ret, {"test_core": "0.2.0"}) def test_get_namespaces_no_specloc(self): """Test getting namespaces where the file does not contain a SPEC_LOC_ATTR.""" - self.write_test_file('test_core', '0.1.0', 'w') + self.write_test_file("test_core", "0.1.0", "w") # delete the spec location attribute from the file - with h5py.File(self.path, mode='r+') as f: + with h5py.File(self.path, mode="r+") as f: del f.attrs[SPEC_LOC_ATTR] # load the namespace from file @@ -2273,80 +2528,85 @@ def tearDown(self): def test_basic(self): """Test that exporting a written container works.""" - foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) - foobucket = FooBucket('bucket1', [foo1]) + foo1 = Foo("foo1", [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) + foobucket = FooBucket("bucket1", [foo1]) foofile = FooFile(buckets=[foobucket]) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='w') as write_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="w") as write_io: write_io.write(foofile) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='r') as read_io: - with HDF5IO(self.paths[1], mode='w') as export_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="r") as read_io: + with HDF5IO(self.paths[1], 
mode="w") as export_io: export_io.export(src_io=read_io) self.assertTrue(os.path.exists(self.paths[1])) self.assertEqual(foofile.container_source, self.paths[0]) - with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode='r') as read_io: + with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode="r") as read_io: read_foofile = read_io.read() self.assertEqual(read_foofile.container_source, self.paths[1]) self.assertContainerEqual(foofile, read_foofile, ignore_hdmf_attrs=True) - self.assertEqual(os.path.abspath(read_foofile.buckets['bucket1'].foos['foo1'].my_data.file.filename), - self.paths[1]) + self.assertEqual( + os.path.abspath(read_foofile.buckets["bucket1"].foos["foo1"].my_data.file.filename), + self.paths[1], + ) def test_basic_container(self): """Test that exporting a written container, passing in the container arg, works.""" - foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) - foobucket = FooBucket('bucket1', [foo1]) + foo1 = Foo("foo1", [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) + foobucket = FooBucket("bucket1", [foo1]) foofile = FooFile(buckets=[foobucket]) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='w') as write_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="w") as write_io: write_io.write(foofile) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='r') as read_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="r") as read_io: read_foofile = read_io.read() - with HDF5IO(self.paths[1], mode='w') as export_io: + with HDF5IO(self.paths[1], mode="w") as export_io: export_io.export(src_io=read_io, container=read_foofile) self.assertTrue(os.path.exists(self.paths[1])) self.assertEqual(foofile.container_source, self.paths[0]) - with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode='r') as read_io: + with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode="r") as read_io: read_foofile = read_io.read() self.assertEqual(read_foofile.container_source, self.paths[1]) self.assertContainerEqual(foofile, read_foofile, ignore_hdmf_attrs=True) def test_container_part(self): """Test that exporting a part of a written container raises an error.""" - foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) - foobucket = FooBucket('bucket1', [foo1]) + foo1 = Foo("foo1", [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) + foobucket = FooBucket("bucket1", [foo1]) foofile = FooFile(buckets=[foobucket]) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='w') as write_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="w") as write_io: write_io.write(foofile) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='r') as read_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="r") as read_io: read_foofile = read_io.read() - with HDF5IO(self.paths[1], mode='w') as export_io: - msg = ("The provided container must be the root of the hierarchy of the source used to read the " - "container.") + with HDF5IO(self.paths[1], mode="w") as export_io: + msg = ( + "The provided container must be the root of the hierarchy of the source used to read the container." 
+ ) with self.assertRaisesWith(ValueError, msg): - export_io.export(src_io=read_io, container=read_foofile.buckets['bucket1']) + export_io.export( + src_io=read_io, + container=read_foofile.buckets["bucket1"], + ) def test_container_unknown(self): """Test that exporting a container that did not come from the src_io object raises an error.""" - foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) - foobucket = FooBucket('bucket1', [foo1]) + foo1 = Foo("foo1", [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) + foobucket = FooBucket("bucket1", [foo1]) foofile = FooFile(buckets=[foobucket]) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='w') as write_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="w") as write_io: write_io.write(foofile) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='r') as read_io: - - with HDF5IO(self.paths[1], mode='w') as export_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="r") as read_io: + with HDF5IO(self.paths[1], mode="w") as export_io: dummy_file = FooFile(buckets=[]) msg = "The provided container must have been read by the provided src_io." with self.assertRaisesWith(ValueError, msg): @@ -2354,62 +2614,61 @@ def test_container_unknown(self): def test_cache_spec_true(self): """Test that exporting with cache_spec works.""" - foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) - foobucket = FooBucket('bucket1', [foo1]) + foo1 = Foo("foo1", [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) + foobucket = FooBucket("bucket1", [foo1]) foofile = FooFile(buckets=[foobucket]) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='w') as write_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="w") as write_io: write_io.write(foofile) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='r') as read_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="r") as read_io: read_foofile = read_io.read() - with HDF5IO(self.paths[1], mode='w') as export_io: + with HDF5IO(self.paths[1], mode="w") as export_io: export_io.export( src_io=read_io, container=read_foofile, ) - with File(self.paths[1], 'r') as f: + with File(self.paths[1], "r") as f: self.assertIn("test_core", f["specifications"]) def test_cache_spec_false(self): """Test that exporting with cache_spec works.""" - foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) - foobucket = FooBucket('bucket1', [foo1]) + foo1 = Foo("foo1", [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) + foobucket = FooBucket("bucket1", [foo1]) foofile = FooFile(buckets=[foobucket]) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='w') as write_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="w") as write_io: write_io.write(foofile) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='r') as read_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="r") as read_io: read_foofile = read_io.read() - with HDF5IO(self.paths[1], mode='w') as export_io: + with HDF5IO(self.paths[1], mode="w") as export_io: export_io.export( src_io=read_io, container=read_foofile, cache_spec=False, ) - with File(self.paths[1], 'r') as f: - self.assertNotIn('specifications', f) + with File(self.paths[1], "r") as f: + self.assertNotIn("specifications", f) def test_soft_link_group(self): """Test that exporting a written file with soft linked groups keeps links within the file.""" - foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) - foobucket = FooBucket('bucket1', 
[foo1]) + foo1 = Foo("foo1", [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) + foobucket = FooBucket("bucket1", [foo1]) foofile = FooFile(buckets=[foobucket], foo_link=foo1) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='w') as write_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="w") as write_io: write_io.write(foofile) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='r') as read_io: - - with HDF5IO(self.paths[1], mode='w') as export_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="r") as read_io: + with HDF5IO(self.paths[1], mode="w") as export_io: export_io.export(src_io=read_io) - with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode='r') as read_io: + with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode="r") as read_io: self.ios.append(read_io) # track IO objects for tearDown read_foofile2 = read_io.read() @@ -2418,20 +2677,20 @@ def test_soft_link_group(self): def test_soft_link_dataset(self): """Test that exporting a written file with soft linked datasets keeps links within the file.""" - foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) - foobucket = FooBucket('bucket1', [foo1]) + foo1 = Foo("foo1", [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) + foobucket = FooBucket("bucket1", [foo1]) foofile = FooFile(buckets=[foobucket], foofile_data=foo1.my_data) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='w') as write_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="w") as write_io: write_io.write(foofile) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='r') as read_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="r") as read_io: self.ios.append(read_io) # track IO objects for tearDown - with HDF5IO(self.paths[1], mode='w') as export_io: + with HDF5IO(self.paths[1], mode="w") as export_io: export_io.export(src_io=read_io) - with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode='r') as read_io: + with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode="r") as read_io: self.ios.append(read_io) # track IO objects for tearDown read_foofile2 = read_io.read() @@ -2440,21 +2699,21 @@ def test_soft_link_dataset(self): def test_soft_link_group_modified(self): """Test that exporting a written file with soft linked groups keeps links within the file.""" - foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) - foobucket = FooBucket('bucket1', [foo1]) + foo1 = Foo("foo1", [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) + foobucket = FooBucket("bucket1", [foo1]) foofile = FooFile(buckets=[foobucket], foo_link=foo1) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='w') as write_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="w") as write_io: write_io.write(foofile) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='r') as read_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="r") as read_io: read_foofile2 = read_io.read() read_foofile2.foo_link.set_modified() # trigger a rebuild of foo_link and its parents - with HDF5IO(self.paths[1], mode='w') as export_io: + with HDF5IO(self.paths[1], mode="w") as export_io: export_io.export(src_io=read_io, container=read_foofile2) - with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode='r') as read_io: + with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode="r") as read_io: self.ios.append(read_io) # track IO objects for tearDown read_foofile2 = read_io.read() @@ -2462,30 
+2721,30 @@ def test_soft_link_group_modified(self): self.assertEqual(read_foofile2.foo_link.container_source, self.paths[1]) # make sure the linked group is a soft link - with File(self.paths[1], 'r') as f: - self.assertEqual(f['links/foo_link'].file.filename, self.paths[1]) - self.assertIsInstance(f.get('links/foo_link', getlink=True), h5py.SoftLink) + with File(self.paths[1], "r") as f: + self.assertEqual(f["links/foo_link"].file.filename, self.paths[1]) + self.assertIsInstance(f.get("links/foo_link", getlink=True), h5py.SoftLink) def test_soft_link_group_modified_rel_path(self): """Test that exporting a written file with soft linked groups keeps links within the file.""" - foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) - foobucket = FooBucket('bucket1', [foo1]) + foo1 = Foo("foo1", [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) + foobucket = FooBucket("bucket1", [foo1]) foofile = FooFile(buckets=[foobucket], foo_link=foo1) # make temp files in relative path location self.paths[0] = os.path.basename(self.paths[0]) self.paths[1] = os.path.basename(self.paths[1]) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='w') as write_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="w") as write_io: write_io.write(foofile) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='r') as read_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="r") as read_io: read_foofile2 = read_io.read() read_foofile2.foo_link.set_modified() # trigger a rebuild of foo_link and its parents - with HDF5IO(self.paths[1], mode='w') as export_io: + with HDF5IO(self.paths[1], mode="w") as export_io: export_io.export(src_io=read_io, container=read_foofile2) - with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode='r') as read_io: + with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode="r") as read_io: self.ios.append(read_io) # track IO objects for tearDown read_foofile2 = read_io.read() @@ -2493,35 +2752,35 @@ def test_soft_link_group_modified_rel_path(self): self.assertEqual(read_foofile2.foo_link.container_source, os.path.abspath(self.paths[1])) # make sure the linked group is a soft link - with File(self.paths[1], 'r') as f: - self.assertEqual(f['links/foo_link'].file.filename, self.paths[1]) - self.assertIsInstance(f.get('links/foo_link', getlink=True), h5py.SoftLink) + with File(self.paths[1], "r") as f: + self.assertEqual(f["links/foo_link"].file.filename, self.paths[1]) + self.assertIsInstance(f.get("links/foo_link", getlink=True), h5py.SoftLink) def test_external_link_group(self): """Test that exporting a written file with external linked groups maintains the links.""" - foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) - foobucket = FooBucket('bucket1', [foo1]) + foo1 = Foo("foo1", [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) + foobucket = FooBucket("bucket1", [foo1]) foofile = FooFile(buckets=[foobucket]) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='w') as read_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="w") as read_io: read_io.write(foofile) manager = get_foo_buildmanager() - with HDF5IO(self.paths[0], manager=manager, mode='r') as read_io: + with HDF5IO(self.paths[0], manager=manager, mode="r") as read_io: read_foofile = read_io.read() # make external link to existing group - foofile2 = FooFile(foo_link=read_foofile.buckets['bucket1'].foos['foo1']) + foofile2 = FooFile(foo_link=read_foofile.buckets["bucket1"].foos["foo1"]) - with HDF5IO(self.paths[1], manager=manager, 
mode='w') as write_io: + with HDF5IO(self.paths[1], manager=manager, mode="w") as write_io: write_io.write(foofile2) - with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode='r') as read_io: + with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode="r") as read_io: self.ios.append(read_io) # track IO objects for tearDown - with HDF5IO(self.paths[2], mode='w') as export_io: + with HDF5IO(self.paths[2], mode="w") as export_io: export_io.export(src_io=read_io) - with HDF5IO(self.paths[2], manager=get_foo_buildmanager(), mode='r') as read_io: + with HDF5IO(self.paths[2], manager=get_foo_buildmanager(), mode="r") as read_io: self.ios.append(read_io) # track IO objects for tearDown read_foofile2 = read_io.read() @@ -2530,33 +2789,33 @@ def test_external_link_group(self): def test_external_link_group_rel_path(self): """Test that exporting a written file from a relative filepath works.""" - foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) - foobucket = FooBucket('bucket1', [foo1]) + foo1 = Foo("foo1", [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) + foobucket = FooBucket("bucket1", [foo1]) foofile = FooFile(buckets=[foobucket]) # make temp files in relative path location self.paths[0] = os.path.basename(self.paths[0]) self.paths[1] = os.path.basename(self.paths[1]) self.paths[2] = os.path.basename(self.paths[2]) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='w') as read_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="w") as read_io: read_io.write(foofile) manager = get_foo_buildmanager() - with HDF5IO(self.paths[0], manager=manager, mode='r') as read_io: + with HDF5IO(self.paths[0], manager=manager, mode="r") as read_io: read_foofile = read_io.read() # make external link to existing group - foofile2 = FooFile(foo_link=read_foofile.buckets['bucket1'].foos['foo1']) + foofile2 = FooFile(foo_link=read_foofile.buckets["bucket1"].foos["foo1"]) - with HDF5IO(self.paths[1], manager=manager, mode='w') as write_io: + with HDF5IO(self.paths[1], manager=manager, mode="w") as write_io: write_io.write(foofile2) - with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode='r') as read_io: + with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode="r") as read_io: self.ios.append(read_io) # track IO objects for tearDown - with HDF5IO(self.paths[2], mode='w') as export_io: + with HDF5IO(self.paths[2], mode="w") as export_io: export_io.export(src_io=read_io) - with HDF5IO(self.paths[2], manager=get_foo_buildmanager(), mode='r') as read_io: + with HDF5IO(self.paths[2], manager=get_foo_buildmanager(), mode="r") as read_io: self.ios.append(read_io) # track IO objects for tearDown read_foofile2 = read_io.read() @@ -2565,28 +2824,28 @@ def test_external_link_group_rel_path(self): def test_external_link_dataset(self): """Test that exporting a written file with external linked datasets maintains the links.""" - foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) - foobucket = FooBucket('bucket1', [foo1]) + foo1 = Foo("foo1", [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) + foobucket = FooBucket("bucket1", [foo1]) foofile = FooFile(buckets=[foobucket], foofile_data=[1, 2, 3]) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='w') as write_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="w") as write_io: write_io.write(foofile) manager = get_foo_buildmanager() - with HDF5IO(self.paths[0], manager=manager, mode='r') as read_io: + with HDF5IO(self.paths[0], manager=manager, mode="r") as read_io: read_foofile = 
read_io.read() foofile2 = FooFile(foofile_data=read_foofile.foofile_data) # make external link to existing dataset - with HDF5IO(self.paths[1], manager=manager, mode='w') as write_io: + with HDF5IO(self.paths[1], manager=manager, mode="w") as write_io: write_io.write(foofile2) - with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode='r') as read_io: + with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode="r") as read_io: self.ios.append(read_io) # track IO objects for tearDown - with HDF5IO(self.paths[2], mode='w') as export_io: + with HDF5IO(self.paths[2], mode="w") as export_io: export_io.export(src_io=read_io) - with HDF5IO(self.paths[2], manager=get_foo_buildmanager(), mode='r') as read_io: + with HDF5IO(self.paths[2], manager=get_foo_buildmanager(), mode="r") as read_io: self.ios.append(read_io) # track IO objects for tearDown read_foofile2 = read_io.read() @@ -2595,38 +2854,38 @@ def test_external_link_dataset(self): def test_external_link_link(self): """Test that exporting a written file with external links to external links maintains the links.""" - foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) - foobucket = FooBucket('bucket1', [foo1]) + foo1 = Foo("foo1", [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) + foobucket = FooBucket("bucket1", [foo1]) foofile = FooFile(buckets=[foobucket]) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='w') as write_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="w") as write_io: write_io.write(foofile) manager = get_foo_buildmanager() - with HDF5IO(self.paths[0], manager=manager, mode='r') as read_io: + with HDF5IO(self.paths[0], manager=manager, mode="r") as read_io: read_foofile = read_io.read() # make external link to existing group - foofile2 = FooFile(foo_link=read_foofile.buckets['bucket1'].foos['foo1']) + foofile2 = FooFile(foo_link=read_foofile.buckets["bucket1"].foos["foo1"]) - with HDF5IO(self.paths[1], manager=manager, mode='w') as write_io: + with HDF5IO(self.paths[1], manager=manager, mode="w") as write_io: write_io.write(foofile2) manager = get_foo_buildmanager() - with HDF5IO(self.paths[1], manager=manager, mode='r') as read_io: + with HDF5IO(self.paths[1], manager=manager, mode="r") as read_io: self.ios.append(read_io) # track IO objects for tearDown read_foofile2 = read_io.read() foofile3 = FooFile(foo_link=read_foofile2.foo_link) # make external link to external link - with HDF5IO(self.paths[2], manager=manager, mode='w') as write_io: + with HDF5IO(self.paths[2], manager=manager, mode="w") as write_io: write_io.write(foofile3) - with HDF5IO(self.paths[2], manager=get_foo_buildmanager(), mode='r') as read_io: + with HDF5IO(self.paths[2], manager=get_foo_buildmanager(), mode="r") as read_io: self.ios.append(read_io) # track IO objects for tearDown - with HDF5IO(self.paths[3], mode='w') as export_io: + with HDF5IO(self.paths[3], mode="w") as export_io: export_io.export(src_io=read_io) - with HDF5IO(self.paths[3], manager=get_foo_buildmanager(), mode='r') as read_io: + with HDF5IO(self.paths[3], manager=get_foo_buildmanager(), mode="r") as read_io: self.ios.append(read_io) # track IO objects for tearDown read_foofile3 = read_io.read() @@ -2635,23 +2894,23 @@ def test_external_link_link(self): def test_new_soft_link(self): """Test that exporting a file with a newly created soft link makes the link internally.""" - foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) - foobucket = FooBucket('bucket1', [foo1]) + foo1 = Foo("foo1", [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) 
+ foobucket = FooBucket("bucket1", [foo1]) foofile = FooFile(buckets=[foobucket]) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='w') as write_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="w") as write_io: write_io.write(foofile) manager = get_foo_buildmanager() - with HDF5IO(self.paths[0], manager=manager, mode='r') as read_io: + with HDF5IO(self.paths[0], manager=manager, mode="r") as read_io: read_foofile = read_io.read() # make external link to existing group - read_foofile.foo_link = read_foofile.buckets['bucket1'].foos['foo1'] + read_foofile.foo_link = read_foofile.buckets["bucket1"].foos["foo1"] - with HDF5IO(self.paths[1], mode='w') as export_io: + with HDF5IO(self.paths[1], mode="w") as export_io: export_io.export(src_io=read_io, container=read_foofile) - with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode='r') as read_io: + with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode="r") as read_io: self.ios.append(read_io) # track IO objects for tearDown read_foofile2 = read_io.read() @@ -2660,44 +2919,46 @@ def test_new_soft_link(self): def test_attr_reference(self): """Test that exporting a written file with attribute references maintains the references.""" - foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) - foobucket = FooBucket('bucket1', [foo1]) + foo1 = Foo("foo1", [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) + foobucket = FooBucket("bucket1", [foo1]) foofile = FooFile(buckets=[foobucket], foo_ref_attr=foo1) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='w') as read_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="w") as read_io: read_io.write(foofile) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='r') as read_io: - - with HDF5IO(self.paths[1], mode='w') as export_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="r") as read_io: + with HDF5IO(self.paths[1], mode="w") as export_io: export_io.export(src_io=read_io) - with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode='r') as read_io: + with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode="r") as read_io: read_foofile2 = read_io.read() # make sure the attribute reference resolves to the container within the same file - self.assertIs(read_foofile2.foo_ref_attr, read_foofile2.buckets['bucket1'].foos['foo1']) + self.assertIs( + read_foofile2.foo_ref_attr, + read_foofile2.buckets["bucket1"].foos["foo1"], + ) - with File(self.paths[1], 'r') as f: - self.assertIsInstance(f.attrs['foo_ref_attr'], h5py.Reference) + with File(self.paths[1], "r") as f: + self.assertIsInstance(f.attrs["foo_ref_attr"], h5py.Reference) def test_pop_data(self): """Test that exporting a written container after removing an element from it works.""" - foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) - foobucket = FooBucket('bucket1', [foo1]) + foo1 = Foo("foo1", [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) + foobucket = FooBucket("bucket1", [foo1]) foofile = FooFile(buckets=[foobucket]) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='w') as write_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="w") as write_io: write_io.write(foofile) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='r') as read_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="r") as read_io: read_foofile = read_io.read() - read_foofile.remove_bucket('bucket1') # remove child group + read_foofile.remove_bucket("bucket1") # remove child 
group - with HDF5IO(self.paths[1], mode='w') as export_io: + with HDF5IO(self.paths[1], mode="w") as export_io: export_io.export(src_io=read_io, container=read_foofile) - with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode='r') as read_io: + with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode="r") as read_io: read_foofile2 = read_io.read() # make sure the read foofile has no buckets @@ -2708,39 +2969,47 @@ def test_pop_data(self): def test_pop_linked_group(self): """Test that exporting a written container after removing a linked element from it works.""" - foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) - foobucket = FooBucket('bucket1', [foo1]) + foo1 = Foo("foo1", [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) + foobucket = FooBucket("bucket1", [foo1]) foofile = FooFile(buckets=[foobucket], foo_link=foo1) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='w') as write_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="w") as write_io: write_io.write(foofile) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='r') as read_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="r") as read_io: read_foofile = read_io.read() - read_foofile.buckets['bucket1'].remove_foo('foo1') # remove child group + read_foofile.buckets["bucket1"].remove_foo("foo1") # remove child group - with HDF5IO(self.paths[1], mode='w') as export_io: - msg = ("links (links): Linked Foo 'foo1' has no parent. Remove the link or ensure the linked " - "container is added properly.") + with HDF5IO(self.paths[1], mode="w") as export_io: + msg = ( + "links (links): Linked Foo 'foo1' has no parent. Remove the link or" + " ensure the linked container is added properly." + ) with self.assertRaisesWith(OrphanContainerBuildError, msg): export_io.export(src_io=read_io, container=read_foofile) def test_append_data(self): """Test that exporting a written container after adding groups, links, and references to it works.""" - foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) - foobucket = FooBucket('bucket1', [foo1]) + foo1 = Foo("foo1", [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) + foobucket = FooBucket("bucket1", [foo1]) foofile = FooFile(buckets=[foobucket]) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='w') as write_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="w") as write_io: write_io.write(foofile) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='r') as read_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="r") as read_io: read_foofile = read_io.read() # create a foo with link to existing dataset my_data, add the foo to new foobucket # this should make a soft link within the exported file - foo2 = Foo('foo2', read_foofile.buckets['bucket1'].foos['foo1'].my_data, "I am foo2", 17, 3.14) - foobucket2 = FooBucket('bucket2', [foo2]) + foo2 = Foo( + "foo2", + read_foofile.buckets["bucket1"].foos["foo1"].my_data, + "I am foo2", + 17, + 3.14, + ) + foobucket2 = FooBucket("bucket2", [foo2]) read_foofile.add_bucket(foobucket2) # also add link from foofile to new foo2 container @@ -2752,298 +3021,351 @@ def test_append_data(self): # also add reference from foofile to new foo2 read_foofile.foo_ref_attr = foo2 - with HDF5IO(self.paths[1], mode='w') as export_io: + with HDF5IO(self.paths[1], mode="w") as export_io: export_io.export(src_io=read_io, container=read_foofile) - with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode='r') as read_io: + with 
HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode="r") as read_io: self.ios.append(read_io) # track IO objects for tearDown read_foofile2 = read_io.read() # test new soft link to dataset in file - self.assertIs(read_foofile2.buckets['bucket1'].foos['foo1'].my_data, - read_foofile2.buckets['bucket2'].foos['foo2'].my_data) + self.assertIs( + read_foofile2.buckets["bucket1"].foos["foo1"].my_data, + read_foofile2.buckets["bucket2"].foos["foo2"].my_data, + ) # test new soft link to group in file - self.assertIs(read_foofile2.foo_link, read_foofile2.buckets['bucket2'].foos['foo2']) + self.assertIs( + read_foofile2.foo_link, + read_foofile2.buckets["bucket2"].foos["foo2"], + ) # test new soft link to new soft link to dataset in file - self.assertIs(read_foofile2.buckets['bucket1'].foos['foo1'].my_data, read_foofile2.foofile_data) + self.assertIs( + read_foofile2.buckets["bucket1"].foos["foo1"].my_data, + read_foofile2.foofile_data, + ) # test new attribute reference to new group in file - self.assertIs(read_foofile2.foo_ref_attr, read_foofile2.buckets['bucket2'].foos['foo2']) + self.assertIs( + read_foofile2.foo_ref_attr, + read_foofile2.buckets["bucket2"].foos["foo2"], + ) - with File(self.paths[1], 'r') as f: - self.assertEqual(f['foofile_data'].file.filename, self.paths[1]) - self.assertIsInstance(f.attrs['foo_ref_attr'], h5py.Reference) + with File(self.paths[1], "r") as f: + self.assertEqual(f["foofile_data"].file.filename, self.paths[1]) + self.assertIsInstance(f.attrs["foo_ref_attr"], h5py.Reference) def test_append_external_link_data(self): """Test that exporting a written container after adding a link with link_data=True creates external links.""" - foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) - foobucket = FooBucket('bucket1', [foo1]) + foo1 = Foo("foo1", [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) + foobucket = FooBucket("bucket1", [foo1]) foofile = FooFile(buckets=[foobucket]) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='w') as write_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="w") as write_io: write_io.write(foofile) foofile2 = FooFile(buckets=[]) - with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode='w') as write_io: + with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode="w") as write_io: write_io.write(foofile2) manager = get_foo_buildmanager() - with HDF5IO(self.paths[0], manager=manager, mode='r') as read_io1: + with HDF5IO(self.paths[0], manager=manager, mode="r") as read_io1: self.ios.append(read_io1) # track IO objects for tearDown read_foofile1 = read_io1.read() - with HDF5IO(self.paths[1], manager=manager, mode='r') as read_io2: + with HDF5IO(self.paths[1], manager=manager, mode="r") as read_io2: self.ios.append(read_io2) read_foofile2 = read_io2.read() # create a foo with link to existing dataset my_data (not in same file), add the foo to new foobucket # this should make an external link within the exported file - foo2 = Foo('foo2', read_foofile1.buckets['bucket1'].foos['foo1'].my_data, "I am foo2", 17, 3.14) - foobucket2 = FooBucket('bucket2', [foo2]) + foo2 = Foo( + "foo2", + read_foofile1.buckets["bucket1"].foos["foo1"].my_data, + "I am foo2", + 17, + 3.14, + ) + foobucket2 = FooBucket("bucket2", [foo2]) read_foofile2.add_bucket(foobucket2) # also add link from foofile to new foo2.my_data dataset which is a link to foo1.my_data dataset # this should make an external link within the exported file read_foofile2.foofile_data = foo2.my_data - with HDF5IO(self.paths[2], mode='w') as 
export_io: + with HDF5IO(self.paths[2], mode="w") as export_io: export_io.export(src_io=read_io2, container=read_foofile2) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='r') as read_io1: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="r") as read_io1: self.ios.append(read_io1) # track IO objects for tearDown read_foofile3 = read_io1.read() - with HDF5IO(self.paths[2], manager=get_foo_buildmanager(), mode='r') as read_io2: + with HDF5IO(self.paths[2], manager=get_foo_buildmanager(), mode="r") as read_io2: self.ios.append(read_io2) # track IO objects for tearDown read_foofile4 = read_io2.read() - self.assertEqual(read_foofile4.buckets['bucket2'].foos['foo2'].my_data, - read_foofile3.buckets['bucket1'].foos['foo1'].my_data) - self.assertEqual(read_foofile4.foofile_data, read_foofile3.buckets['bucket1'].foos['foo1'].my_data) + self.assertEqual( + read_foofile4.buckets["bucket2"].foos["foo2"].my_data, + read_foofile3.buckets["bucket1"].foos["foo1"].my_data, + ) + self.assertEqual( + read_foofile4.foofile_data, + read_foofile3.buckets["bucket1"].foos["foo1"].my_data, + ) - with File(self.paths[2], 'r') as f: - self.assertEqual(f['buckets/bucket2/foo_holder/foo2/my_data'].file.filename, self.paths[0]) - self.assertEqual(f['foofile_data'].file.filename, self.paths[0]) - self.assertIsInstance(f.get('buckets/bucket2/foo_holder/foo2/my_data', getlink=True), - h5py.ExternalLink) - self.assertIsInstance(f.get('foofile_data', getlink=True), h5py.ExternalLink) + with File(self.paths[2], "r") as f: + self.assertEqual( + f["buckets/bucket2/foo_holder/foo2/my_data"].file.filename, + self.paths[0], + ) + self.assertEqual(f["foofile_data"].file.filename, self.paths[0]) + self.assertIsInstance( + f.get("buckets/bucket2/foo_holder/foo2/my_data", getlink=True), + h5py.ExternalLink, + ) + self.assertIsInstance(f.get("foofile_data", getlink=True), h5py.ExternalLink) def test_append_external_link_copy_data(self): """Test that exporting a written container after adding a link with link_data=False copies the data.""" - foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) - foobucket = FooBucket('bucket1', [foo1]) + foo1 = Foo("foo1", [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) + foobucket = FooBucket("bucket1", [foo1]) foofile = FooFile(buckets=[foobucket]) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='w') as write_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="w") as write_io: write_io.write(foofile) foofile2 = FooFile(buckets=[]) - with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode='w') as write_io: + with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode="w") as write_io: write_io.write(foofile2) manager = get_foo_buildmanager() - with HDF5IO(self.paths[0], manager=manager, mode='r') as read_io1: + with HDF5IO(self.paths[0], manager=manager, mode="r") as read_io1: self.ios.append(read_io1) # track IO objects for tearDown read_foofile1 = read_io1.read() - with HDF5IO(self.paths[1], manager=manager, mode='r') as read_io2: + with HDF5IO(self.paths[1], manager=manager, mode="r") as read_io2: self.ios.append(read_io2) read_foofile2 = read_io2.read() # create a foo with link to existing dataset my_data (not in same file), add the foo to new foobucket # this would normally make an external link but because link_data=False, data will be copied - foo2 = Foo('foo2', read_foofile1.buckets['bucket1'].foos['foo1'].my_data, "I am foo2", 17, 3.14) - foobucket2 = FooBucket('bucket2', [foo2]) + foo2 = Foo( + "foo2", + 
read_foofile1.buckets["bucket1"].foos["foo1"].my_data, + "I am foo2", + 17, + 3.14, + ) + foobucket2 = FooBucket("bucket2", [foo2]) read_foofile2.add_bucket(foobucket2) # also add link from foofile to new foo2.my_data dataset which is a link to foo1.my_data dataset # this would normally make an external link but because link_data=False, data will be copied read_foofile2.foofile_data = foo2.my_data - with HDF5IO(self.paths[2], mode='w') as export_io: - export_io.export(src_io=read_io2, container=read_foofile2, write_args={'link_data': False}) + with HDF5IO(self.paths[2], mode="w") as export_io: + export_io.export( + src_io=read_io2, + container=read_foofile2, + write_args={"link_data": False}, + ) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='r') as read_io1: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="r") as read_io1: self.ios.append(read_io1) # track IO objects for tearDown read_foofile3 = read_io1.read() - with HDF5IO(self.paths[2], manager=get_foo_buildmanager(), mode='r') as read_io2: + with HDF5IO(self.paths[2], manager=get_foo_buildmanager(), mode="r") as read_io2: self.ios.append(read_io2) # track IO objects for tearDown read_foofile4 = read_io2.read() # check that file can be read - self.assertNotEqual(read_foofile4.buckets['bucket2'].foos['foo2'].my_data, - read_foofile3.buckets['bucket1'].foos['foo1'].my_data) - self.assertNotEqual(read_foofile4.foofile_data, read_foofile3.buckets['bucket1'].foos['foo1'].my_data) - self.assertNotEqual(read_foofile4.foofile_data, read_foofile4.buckets['bucket2'].foos['foo2'].my_data) + self.assertNotEqual( + read_foofile4.buckets["bucket2"].foos["foo2"].my_data, + read_foofile3.buckets["bucket1"].foos["foo1"].my_data, + ) + self.assertNotEqual( + read_foofile4.foofile_data, + read_foofile3.buckets["bucket1"].foos["foo1"].my_data, + ) + self.assertNotEqual( + read_foofile4.foofile_data, + read_foofile4.buckets["bucket2"].foos["foo2"].my_data, + ) - with File(self.paths[2], 'r') as f: - self.assertEqual(f['buckets/bucket2/foo_holder/foo2/my_data'].file.filename, self.paths[2]) - self.assertEqual(f['foofile_data'].file.filename, self.paths[2]) + with File(self.paths[2], "r") as f: + self.assertEqual( + f["buckets/bucket2/foo_holder/foo2/my_data"].file.filename, + self.paths[2], + ) + self.assertEqual(f["foofile_data"].file.filename, self.paths[2]) def test_export_simple_link_data(self): """Test simple exporting of data with a link with link_data=True links the data.""" - foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) - foobucket = FooBucket('bucket1', [foo1]) + foo1 = Foo("foo1", [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) + foobucket = FooBucket("bucket1", [foo1]) foofile = FooFile([foobucket]) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='w') as write_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="w") as write_io: write_io.write(foofile) # create new foofile with link from foo2.data to read foo1.data - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='r') as read_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="r") as read_io: read_foofile1 = read_io.read() - foo2 = Foo('foo2', read_foofile1.buckets['bucket1'].foos['foo1'].my_data, "I am foo2", 17, 3.14) - foobucket2 = FooBucket('bucket2', [foo2]) + foo2 = Foo("foo2", read_foofile1.buckets["bucket1"].foos["foo1"].my_data, "I am foo2", 17, 3.14) + foobucket2 = FooBucket("bucket2", [foo2]) foofile2 = FooFile([foobucket2]) # also add link from foofile to new 
foo2.my_data dataset which is a link to foo1.my_data dataset # this should make an external link within the exported file foofile2.foofile_data = foo2.my_data - with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode='w') as write_io: + with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode="w") as write_io: write_io.write(foofile2) # read the data with the linked dataset, do not modify it, and export it - with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode='r') as read_io: - with HDF5IO(self.paths[2], mode='w') as export_io: + with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode="r") as read_io: + with HDF5IO(self.paths[2], mode="w") as export_io: export_io.export(src_io=read_io) # read the exported file and confirm that the dataset is linked to the correct foofile1 - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='r') as read_io1: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="r") as read_io1: self.ios.append(read_io1) # track IO objects for tearDown read_foofile3 = read_io1.read() - with HDF5IO(self.paths[2], manager=get_foo_buildmanager(), mode='r') as read_io2: + with HDF5IO(self.paths[2], manager=get_foo_buildmanager(), mode="r") as read_io2: self.ios.append(read_io2) # track IO objects for tearDown read_foofile4 = read_io2.read() - self.assertEqual(read_foofile4.buckets['bucket2'].foos['foo2'].my_data, - read_foofile3.buckets['bucket1'].foos['foo1'].my_data) - self.assertEqual(read_foofile4.foofile_data, read_foofile3.buckets['bucket1'].foos['foo1'].my_data) + self.assertEqual( + read_foofile4.buckets["bucket2"].foos["foo2"].my_data, + read_foofile3.buckets["bucket1"].foos["foo1"].my_data, + ) + self.assertEqual(read_foofile4.foofile_data, read_foofile3.buckets["bucket1"].foos["foo1"].my_data) - with File(self.paths[2], 'r') as f: - self.assertEqual(f['buckets/bucket2/foo_holder/foo2/my_data'].file.filename, self.paths[0]) - self.assertEqual(f['foofile_data'].file.filename, self.paths[0]) - self.assertIsInstance(f.get('buckets/bucket2/foo_holder/foo2/my_data', getlink=True), - h5py.ExternalLink) - self.assertIsInstance(f.get('foofile_data', getlink=True), h5py.ExternalLink) + with File(self.paths[2], "r") as f: + self.assertEqual(f["buckets/bucket2/foo_holder/foo2/my_data"].file.filename, self.paths[0]) + self.assertEqual(f["foofile_data"].file.filename, self.paths[0]) + self.assertIsInstance(f.get("buckets/bucket2/foo_holder/foo2/my_data", getlink=True), h5py.ExternalLink) + self.assertIsInstance(f.get("foofile_data", getlink=True), h5py.ExternalLink) def test_export_simple_link_data_false(self): """Test simple exporting of data with a link with link_data=False copies the data.""" - foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) - foobucket = FooBucket('bucket1', [foo1]) + foo1 = Foo("foo1", [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) + foobucket = FooBucket("bucket1", [foo1]) foofile = FooFile([foobucket]) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='w') as write_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="w") as write_io: write_io.write(foofile) # create new foofile with link from foo2.data to read foo1.data - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='r') as read_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="r") as read_io: read_foofile1 = read_io.read() - foo2 = Foo('foo2', read_foofile1.buckets['bucket1'].foos['foo1'].my_data, "I am foo2", 17, 3.14) - foobucket2 = FooBucket('bucket2', [foo2]) + foo2 
= Foo("foo2", read_foofile1.buckets["bucket1"].foos["foo1"].my_data, "I am foo2", 17, 3.14) + foobucket2 = FooBucket("bucket2", [foo2]) foofile2 = FooFile([foobucket2]) # also add link from foofile to new foo2.my_data dataset which is a link to foo1.my_data dataset # this should make an external link within the exported file foofile2.foofile_data = foo2.my_data - with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode='w') as write_io: + with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode="w") as write_io: write_io.write(foofile2) # read the data with the linked dataset, do not modify it, and export it - with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode='r') as read_io: - with HDF5IO(self.paths[2], mode='w') as export_io: - export_io.export(src_io=read_io, write_args={'link_data': False}) + with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode="r") as read_io: + with HDF5IO(self.paths[2], mode="w") as export_io: + export_io.export(src_io=read_io, write_args={"link_data": False}) # read the exported file and confirm that the dataset is copied - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='r') as read_io1: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="r") as read_io1: self.ios.append(read_io1) # track IO objects for tearDown read_foofile3 = read_io1.read() - with HDF5IO(self.paths[2], manager=get_foo_buildmanager(), mode='r') as read_io2: + with HDF5IO(self.paths[2], manager=get_foo_buildmanager(), mode="r") as read_io2: self.ios.append(read_io2) # track IO objects for tearDown read_foofile4 = read_io2.read() # check that file can be read - self.assertNotEqual(read_foofile4.buckets['bucket2'].foos['foo2'].my_data, - read_foofile3.buckets['bucket1'].foos['foo1'].my_data) - self.assertNotEqual(read_foofile4.foofile_data, read_foofile3.buckets['bucket1'].foos['foo1'].my_data) - self.assertNotEqual(read_foofile4.foofile_data, read_foofile4.buckets['bucket2'].foos['foo2'].my_data) + self.assertNotEqual( + read_foofile4.buckets["bucket2"].foos["foo2"].my_data, + read_foofile3.buckets["bucket1"].foos["foo1"].my_data, + ) + self.assertNotEqual(read_foofile4.foofile_data, read_foofile3.buckets["bucket1"].foos["foo1"].my_data) + self.assertNotEqual(read_foofile4.foofile_data, read_foofile4.buckets["bucket2"].foos["foo2"].my_data) - with File(self.paths[2], 'r') as f: - self.assertEqual(f['buckets/bucket2/foo_holder/foo2/my_data'].file.filename, self.paths[2]) - self.assertEqual(f['foofile_data'].file.filename, self.paths[2]) + with File(self.paths[2], "r") as f: + self.assertEqual(f["buckets/bucket2/foo_holder/foo2/my_data"].file.filename, self.paths[2]) + self.assertEqual(f["foofile_data"].file.filename, self.paths[2]) def test_export_simple_with_container_link_data_false(self): """Test simple exporting of data with a link with link_data=False copies the data.""" - foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) - foobucket = FooBucket('bucket1', [foo1]) + foo1 = Foo("foo1", [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) + foobucket = FooBucket("bucket1", [foo1]) foofile = FooFile([foobucket]) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='w') as write_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="w") as write_io: write_io.write(foofile) # create new foofile with link from foo2.data to read foo1.data - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='r') as read_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="r") as read_io: 
read_foofile1 = read_io.read() - foo2 = Foo('foo2', read_foofile1.buckets['bucket1'].foos['foo1'].my_data, "I am foo2", 17, 3.14) - foobucket2 = FooBucket('bucket2', [foo2]) + foo2 = Foo("foo2", read_foofile1.buckets["bucket1"].foos["foo1"].my_data, "I am foo2", 17, 3.14) + foobucket2 = FooBucket("bucket2", [foo2]) foofile2 = FooFile([foobucket2]) # also add link from foofile to new foo2.my_data dataset which is a link to foo1.my_data dataset # this should make an external link within the exported file foofile2.foofile_data = foo2.my_data - with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode='w') as write_io: + with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode="w") as write_io: write_io.write(foofile2) # read the data with the linked dataset, do not modify it, and export it - with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode='r') as read_io: + with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode="r") as read_io: read_foofile2 = read_io.read() - with HDF5IO(self.paths[2], mode='w') as export_io: - export_io.export(src_io=read_io, container=read_foofile2, write_args={'link_data': False}) + with HDF5IO(self.paths[2], mode="w") as export_io: + export_io.export(src_io=read_io, container=read_foofile2, write_args={"link_data": False}) # read the exported file and confirm that the dataset is copied - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='r') as read_io1: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="r") as read_io1: self.ios.append(read_io1) # track IO objects for tearDown read_foofile3 = read_io1.read() - with HDF5IO(self.paths[2], manager=get_foo_buildmanager(), mode='r') as read_io2: + with HDF5IO(self.paths[2], manager=get_foo_buildmanager(), mode="r") as read_io2: self.ios.append(read_io2) # track IO objects for tearDown read_foofile4 = read_io2.read() # check that file can be read - self.assertNotEqual(read_foofile4.buckets['bucket2'].foos['foo2'].my_data, - read_foofile3.buckets['bucket1'].foos['foo1'].my_data) - self.assertNotEqual(read_foofile4.foofile_data, read_foofile3.buckets['bucket1'].foos['foo1'].my_data) - self.assertNotEqual(read_foofile4.foofile_data, read_foofile4.buckets['bucket2'].foos['foo2'].my_data) + self.assertNotEqual( + read_foofile4.buckets["bucket2"].foos["foo2"].my_data, + read_foofile3.buckets["bucket1"].foos["foo1"].my_data, + ) + self.assertNotEqual(read_foofile4.foofile_data, read_foofile3.buckets["bucket1"].foos["foo1"].my_data) + self.assertNotEqual(read_foofile4.foofile_data, read_foofile4.buckets["bucket2"].foos["foo2"].my_data) - with File(self.paths[2], 'r') as f: - self.assertEqual(f['buckets/bucket2/foo_holder/foo2/my_data'].file.filename, self.paths[2]) - self.assertEqual(f['foofile_data'].file.filename, self.paths[2]) + with File(self.paths[2], "r") as f: + self.assertEqual(f["buckets/bucket2/foo_holder/foo2/my_data"].file.filename, self.paths[2]) + self.assertEqual(f["foofile_data"].file.filename, self.paths[2]) def test_export_io(self): """Test that exporting a written container using HDF5IO.export_io works.""" - foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) - foobucket = FooBucket('bucket1', [foo1]) + foo1 = Foo("foo1", [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) + foobucket = FooBucket("bucket1", [foo1]) foofile = FooFile(buckets=[foobucket]) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='w') as write_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="w") as write_io: write_io.write(foofile) - with 
HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='r') as read_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="r") as read_io: HDF5IO.export_io(src_io=read_io, path=self.paths[1]) self.assertTrue(os.path.exists(self.paths[1])) self.assertEqual(foofile.container_source, self.paths[0]) - with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode='r') as read_io: + with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode="r") as read_io: read_foofile = read_io.read() self.assertEqual(read_foofile.container_source, self.paths[1]) self.assertContainerEqual(foofile, read_foofile, ignore_hdmf_attrs=True) @@ -3053,34 +3375,34 @@ def test_export_dset_refs(self): bazs = [] num_bazs = 10 for i in range(num_bazs): - bazs.append(Baz(name='baz%d' % i)) - baz_data = BazData(name='baz_data1', data=bazs) - bucket = BazBucket(name='bucket1', bazs=bazs.copy(), baz_data=baz_data) + bazs.append(Baz(name="baz%d" % i)) + baz_data = BazData(name="baz_data1", data=bazs) + bucket = BazBucket(name="bucket1", bazs=bazs.copy(), baz_data=baz_data) - with HDF5IO(self.paths[0], manager=get_baz_buildmanager(), mode='w') as write_io: + with HDF5IO(self.paths[0], manager=get_baz_buildmanager(), mode="w") as write_io: write_io.write(bucket) - with HDF5IO(self.paths[0], manager=get_baz_buildmanager(), mode='r') as read_io: + with HDF5IO(self.paths[0], manager=get_baz_buildmanager(), mode="r") as read_io: read_bucket1 = read_io.read() # NOTE: reference IDs might be the same between two identical files # adding a Baz with a smaller name should change the reference IDs on export - new_baz = Baz(name='baz000') + new_baz = Baz(name="baz000") read_bucket1.add_baz(new_baz) - with HDF5IO(self.paths[1], mode='w') as export_io: + with HDF5IO(self.paths[1], mode="w") as export_io: export_io.export(src_io=read_io, container=read_bucket1) - with HDF5IO(self.paths[1], manager=get_baz_buildmanager(), mode='r') as read_io: + with HDF5IO(self.paths[1], manager=get_baz_buildmanager(), mode="r") as read_io: read_bucket2 = read_io.read() # remove and check the appended child, then compare the read container with the original - read_new_baz = read_bucket2.remove_baz('baz000') + read_new_baz = read_bucket2.remove_baz("baz000") self.assertContainerEqual(new_baz, read_new_baz, ignore_hdmf_attrs=True) self.assertContainerEqual(bucket, read_bucket2, ignore_name=True, ignore_hdmf_attrs=True) for i in range(num_bazs): - baz_name = 'baz%d' % i + baz_name = "baz%d" % i self.assertIs(read_bucket2.baz_data.data[i], read_bucket2.bazs[baz_name]) def test_export_cpd_dset_refs(self): @@ -3089,27 +3411,27 @@ def test_export_cpd_dset_refs(self): baz_pairs = [] num_bazs = 10 for i in range(num_bazs): - b = Baz(name='baz%d' % i) + b = Baz(name="baz%d" % i) bazs.append(b) baz_pairs.append((i, b)) - baz_cpd_data = BazCpdData(name='baz_cpd_data1', data=baz_pairs) - bucket = BazBucket(name='bucket1', bazs=bazs.copy(), baz_cpd_data=baz_cpd_data) + baz_cpd_data = BazCpdData(name="baz_cpd_data1", data=baz_pairs) + bucket = BazBucket(name="bucket1", bazs=bazs.copy(), baz_cpd_data=baz_cpd_data) - with HDF5IO(self.paths[0], manager=get_baz_buildmanager(), mode='w') as write_io: + with HDF5IO(self.paths[0], manager=get_baz_buildmanager(), mode="w") as write_io: write_io.write(bucket) - with HDF5IO(self.paths[0], manager=get_baz_buildmanager(), mode='r') as read_io: + with HDF5IO(self.paths[0], manager=get_baz_buildmanager(), mode="r") as read_io: read_bucket1 = read_io.read() # NOTE: reference IDs might be the same between two 
identical files # adding a Baz with a smaller name should change the reference IDs on export - new_baz = Baz(name='baz000') + new_baz = Baz(name="baz000") read_bucket1.add_baz(new_baz) - with HDF5IO(self.paths[1], mode='w') as export_io: + with HDF5IO(self.paths[1], mode="w") as export_io: export_io.export(src_io=read_io, container=read_bucket1) - with HDF5IO(self.paths[1], manager=get_baz_buildmanager(), mode='r') as read_io: + with HDF5IO(self.paths[1], manager=get_baz_buildmanager(), mode="r") as read_io: read_bucket2 = read_io.read() # remove and check the appended child, then compare the read container with the original @@ -3118,21 +3440,23 @@ def test_export_cpd_dset_refs(self): self.assertContainerEqual(bucket, read_bucket2, ignore_name=True, ignore_hdmf_attrs=True) for i in range(num_bazs): - baz_name = 'baz%d' % i + baz_name = "baz%d" % i self.assertEqual(read_bucket2.baz_cpd_data.data[i][0], i) - self.assertIs(read_bucket2.baz_cpd_data.data[i][1], read_bucket2.bazs[baz_name]) + self.assertIs( + read_bucket2.baz_cpd_data.data[i][1], + read_bucket2.bazs[baz_name], + ) def test_non_manager_container(self): """Test that exporting with a src_io without a manager raises an error.""" - foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) - foobucket = FooBucket('bucket1', [foo1]) + foo1 = Foo("foo1", [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) + foobucket = FooBucket("bucket1", [foo1]) foofile = FooFile(buckets=[foobucket]) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='w') as write_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="w") as write_io: write_io.write(foofile) class OtherIO(HDMFIO): - def read_builder(self): pass @@ -3146,22 +3470,25 @@ def close(self): pass with OtherIO() as read_io: - with HDF5IO(self.paths[1], mode='w') as export_io: - msg = 'When a container is provided, src_io must have a non-None manager (BuildManager) property.' + with HDF5IO(self.paths[1], mode="w") as export_io: + msg = "When a container is provided, src_io must have a non-None manager (BuildManager) property." with self.assertRaisesWith(ValueError, msg): - export_io.export(src_io=read_io, container=foofile, write_args={'link_data': False}) + export_io.export( + src_io=read_io, + container=foofile, + write_args={"link_data": False}, + ) def test_non_HDF5_src_link_data_true(self): """Test that exporting with a src_io without a manager raises an error.""" - foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) - foobucket = FooBucket('bucket1', [foo1]) + foo1 = Foo("foo1", [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) + foobucket = FooBucket("bucket1", [foo1]) foofile = FooFile(buckets=[foobucket]) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='w') as write_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="w") as write_io: write_io.write(foofile) class OtherIO(HDMFIO): - def __init__(self, manager): super().__init__(manager=manager) @@ -3178,102 +3505,107 @@ def close(self): pass with OtherIO(manager=get_foo_buildmanager()) as read_io: - with HDF5IO(self.paths[1], mode='w') as export_io: + with HDF5IO(self.paths[1], mode="w") as export_io: msg = "Cannot export from non-HDF5 backend OtherIO to HDF5 with write argument link_data=True." 
with self.assertRaisesWith(UnsupportedOperation, msg): export_io.export(src_io=read_io, container=foofile) def test_wrong_mode(self): """Test that exporting with a src_io without a manager raises an error.""" - foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) - foobucket = FooBucket('bucket1', [foo1]) + foo1 = Foo("foo1", [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) + foobucket = FooBucket("bucket1", [foo1]) foofile = FooFile(buckets=[foobucket]) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='w') as write_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="w") as write_io: write_io.write(foofile) - with HDF5IO(self.paths[0], mode='r') as read_io: - with HDF5IO(self.paths[1], mode='a') as export_io: + with HDF5IO(self.paths[0], mode="r") as read_io: + with HDF5IO(self.paths[1], mode="a") as export_io: msg = "Cannot export to file %s in mode 'a'. Please use mode 'w'." % self.paths[1] with self.assertRaisesWith(UnsupportedOperation, msg): export_io.export(src_io=read_io) def test_with_new_id(self): """Test that exporting with a src_io without a manager raises an error.""" - foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) - foobucket = FooBucket('bucket1', [foo1]) + foo1 = Foo("foo1", [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) + foobucket = FooBucket("bucket1", [foo1]) foofile = FooFile([foobucket]) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='w') as write_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="w") as write_io: write_io.write(foofile) - with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode='r') as read_io: + with HDF5IO(self.paths[0], manager=get_foo_buildmanager(), mode="r") as read_io: data = read_io.read() original_id = data.object_id data.generate_new_id() - with HDF5IO(self.paths[1], mode='w') as export_io: + with HDF5IO(self.paths[1], mode="w") as export_io: export_io.export(src_io=read_io, container=data) - with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode='r') as read_io: + with HDF5IO(self.paths[1], manager=get_foo_buildmanager(), mode="r") as read_io: data = read_io.read() self.assertTrue(original_id != data.object_id) class TestDatasetRefs(TestCase): - def test_roundtrip(self): self.path = get_temp_filepath() bazs = [] num_bazs = 10 for i in range(num_bazs): - bazs.append(Baz(name='baz%d' % i)) - baz_data = BazData(name='baz_data1', data=bazs) - bucket = BazBucket(name='bucket1', bazs=bazs.copy(), baz_data=baz_data) + bazs.append(Baz(name="baz%d" % i)) + baz_data = BazData(name="baz_data1", data=bazs) + bucket = BazBucket(name="bucket1", bazs=bazs.copy(), baz_data=baz_data) - with HDF5IO(self.path, manager=get_baz_buildmanager(), mode='w') as write_io: + with HDF5IO(self.path, manager=get_baz_buildmanager(), mode="w") as write_io: write_io.write(bucket) - with HDF5IO(self.path, manager=get_baz_buildmanager(), mode='r') as read_io: + with HDF5IO(self.path, manager=get_baz_buildmanager(), mode="r") as read_io: read_bucket = read_io.read() self.assertContainerEqual(bucket, read_bucket, ignore_name=True) for i in range(num_bazs): - baz_name = 'baz%d' % i + baz_name = "baz%d" % i self.assertIs(read_bucket.baz_data.data[i], read_bucket.bazs[baz_name]) class TestCpdDatasetRefs(TestCase): - def test_roundtrip(self): self.path = get_temp_filepath() bazs = [] baz_pairs = [] num_bazs = 10 for i in range(num_bazs): - b = Baz(name='baz%d' % i) + b = Baz(name="baz%d" % i) bazs.append(b) baz_pairs.append((i, b)) - baz_cpd_data = BazCpdData(name='baz_cpd_data1', 
data=baz_pairs) - bucket = BazBucket(name='bucket1', bazs=bazs.copy(), baz_cpd_data=baz_cpd_data) + baz_cpd_data = BazCpdData(name="baz_cpd_data1", data=baz_pairs) + bucket = BazBucket(name="bucket1", bazs=bazs.copy(), baz_cpd_data=baz_cpd_data) - with HDF5IO(self.path, manager=get_baz_buildmanager(), mode='w') as write_io: + with HDF5IO(self.path, manager=get_baz_buildmanager(), mode="w") as write_io: write_io.write(bucket) - with HDF5IO(self.path, manager=get_baz_buildmanager(), mode='r') as read_io: + with HDF5IO(self.path, manager=get_baz_buildmanager(), mode="r") as read_io: read_bucket = read_io.read() self.assertContainerEqual(bucket, read_bucket, ignore_name=True) for i in range(num_bazs): - baz_name = 'baz%d' % i + baz_name = "baz%d" % i self.assertEqual(read_bucket.baz_cpd_data.data[i][0], i) - self.assertIs(read_bucket.baz_cpd_data.data[i][1], read_bucket.bazs[baz_name]) + self.assertIs( + read_bucket.baz_cpd_data.data[i][1], + read_bucket.bazs[baz_name], + ) -@unittest.skipIf(SKIP_ZARR_TESTS, "Skipping TestRoundTripHDF5withZarrInput because Zarr is not installed") +@unittest.skipIf( + SKIP_ZARR_TESTS, + "Skipping TestRoundTripHDF5withZarrInput because Zarr is not installed", +) class TestWriteHDF5withZarrInput(TestCase): """ Test saving data to HDF5 with a zarr.Array as the data """ + def setUp(self): self.manager = get_foo_buildmanager() self.path = get_temp_filepath() @@ -3289,98 +3621,111 @@ def tearDown(self): def test_roundtrip_basic(self): # Setup all the data we need zarr.save(self.zarr_path, np.arange(50).reshape(5, 10)) - zarr_data = zarr.open(self.zarr_path, 'r') - foo1 = Foo(name='foo1', - my_data=zarr_data, - attr1="I am foo1", - attr2=17, - attr3=3.14) - foobucket = FooBucket('bucket1', [foo1]) + zarr_data = zarr.open(self.zarr_path, "r") + foo1 = Foo( + name="foo1", + my_data=zarr_data, + attr1="I am foo1", + attr2=17, + attr3=3.14, + ) + foobucket = FooBucket("bucket1", [foo1]) foofile = FooFile(buckets=[foobucket]) - with HDF5IO(self.path, manager=self.manager, mode='w') as io: + with HDF5IO(self.path, manager=self.manager, mode="w") as io: io.write(foofile) - with HDF5IO(self.path, manager=self.manager, mode='r') as io: + with HDF5IO(self.path, manager=self.manager, mode="r") as io: read_foofile = io.read() - self.assertListEqual(foofile.buckets['bucket1'].foos['foo1'].my_data[:].tolist(), - read_foofile.buckets['bucket1'].foos['foo1'].my_data[:].tolist()) + self.assertListEqual( + foofile.buckets["bucket1"].foos["foo1"].my_data[:].tolist(), + read_foofile.buckets["bucket1"].foos["foo1"].my_data[:].tolist(), + ) def test_roundtrip_empty_dataset(self): - zarr.save(self.zarr_path, np.asarray([]).astype('int64')) - zarr_data = zarr.open(self.zarr_path, 'r') - foo1 = Foo('foo1', zarr_data, "I am foo1", 17, 3.14) - foobucket = FooBucket('bucket1', [foo1]) + zarr.save(self.zarr_path, np.asarray([]).astype("int64")) + zarr_data = zarr.open(self.zarr_path, "r") + foo1 = Foo("foo1", zarr_data, "I am foo1", 17, 3.14) + foobucket = FooBucket("bucket1", [foo1]) foofile = FooFile(buckets=[foobucket]) - with HDF5IO(self.path, manager=self.manager, mode='w') as io: + with HDF5IO(self.path, manager=self.manager, mode="w") as io: io.write(foofile) - with HDF5IO(self.path, manager=self.manager, mode='r') as io: + with HDF5IO(self.path, manager=self.manager, mode="r") as io: read_foofile = io.read() - self.assertListEqual([], read_foofile.buckets['bucket1'].foos['foo1'].my_data[:].tolist()) + self.assertListEqual( + [], + 
read_foofile.buckets["bucket1"].foos["foo1"].my_data[:].tolist(), + ) def test_write_zarr_int32_dataset(self): - base_data = np.arange(50).reshape(5, 10).astype('int32') + base_data = np.arange(50).reshape(5, 10).astype("int32") zarr.save(self.zarr_path, base_data) - zarr_data = zarr.open(self.zarr_path, 'r') - io = HDF5IO(self.path, mode='a') + zarr_data = zarr.open(self.zarr_path, "r") + io = HDF5IO(self.path, mode="a") f = io._file - io.write_dataset(f, DatasetBuilder(name='test_dataset', data=zarr_data, attributes={})) - dset = f['test_dataset'] + io.write_dataset( + f, + DatasetBuilder(name="test_dataset", data=zarr_data, attributes={}), + ) + dset = f["test_dataset"] self.assertTupleEqual(dset.shape, base_data.shape) self.assertEqual(dset.dtype, base_data.dtype) self.assertEqual(base_data.dtype, dset.dtype) - self.assertListEqual(dset[:].tolist(), - base_data.tolist()) + self.assertListEqual(dset[:].tolist(), base_data.tolist()) def test_write_zarr_float32_dataset(self): - base_data = np.arange(50).reshape(5, 10).astype('float32') + base_data = np.arange(50).reshape(5, 10).astype("float32") zarr.save(self.zarr_path, base_data) - zarr_data = zarr.open(self.zarr_path, 'r') - io = HDF5IO(self.path, mode='a') + zarr_data = zarr.open(self.zarr_path, "r") + io = HDF5IO(self.path, mode="a") f = io._file - io.write_dataset(f, DatasetBuilder(name='test_dataset', data=zarr_data, attributes={})) - dset = f['test_dataset'] + io.write_dataset( + f, + DatasetBuilder(name="test_dataset", data=zarr_data, attributes={}), + ) + dset = f["test_dataset"] self.assertTupleEqual(dset.shape, base_data.shape) self.assertEqual(dset.dtype, base_data.dtype) self.assertEqual(base_data.dtype, dset.dtype) - self.assertListEqual(dset[:].tolist(), - base_data.tolist()) + self.assertListEqual(dset[:].tolist(), base_data.tolist()) def test_write_zarr_string_dataset(self): - base_data = np.array(['string1', 'string2'], dtype=str) + base_data = np.array(["string1", "string2"], dtype=str) zarr.save(self.zarr_path, base_data) - zarr_data = zarr.open(self.zarr_path, 'r') - io = HDF5IO(self.path, mode='a') + zarr_data = zarr.open(self.zarr_path, "r") + io = HDF5IO(self.path, mode="a") f = io._file - io.write_dataset(f, DatasetBuilder('test_dataset', zarr_data, attributes={})) - dset = f['test_dataset'] + io.write_dataset(f, DatasetBuilder("test_dataset", zarr_data, attributes={})) + dset = f["test_dataset"] self.assertTupleEqual(dset.shape, (2,)) self.assertListEqual(dset[:].astype(bytes).tolist(), base_data.astype(bytes).tolist()) def test_write_zarr_dataset_compress_gzip(self): - base_data = np.arange(50).reshape(5, 10).astype('float32') + base_data = np.arange(50).reshape(5, 10).astype("float32") zarr.save(self.zarr_path, base_data) - zarr_data = zarr.open(self.zarr_path, 'r') - a = H5DataIO(zarr_data, - compression='gzip', - compression_opts=5, - shuffle=True, - fletcher32=True) - io = HDF5IO(self.path, mode='a') + zarr_data = zarr.open(self.zarr_path, "r") + a = H5DataIO( + zarr_data, + compression="gzip", + compression_opts=5, + shuffle=True, + fletcher32=True, + ) + io = HDF5IO(self.path, mode="a") f = io._file - io.write_dataset(f, DatasetBuilder('test_dataset', a, attributes={})) - dset = f['test_dataset'] + io.write_dataset(f, DatasetBuilder("test_dataset", a, attributes={})) + dset = f["test_dataset"] self.assertTrue(np.all(dset[:] == a.data)) - self.assertEqual(dset.compression, 'gzip') + self.assertEqual(dset.compression, "gzip") self.assertEqual(dset.compression_opts, 5) self.assertEqual(dset.shuffle, True) 
self.assertEqual(dset.fletcher32, True) class HDF5IOEmptyDataset(TestCase): - """ Test if file does not exist, write in mode (w, w-, x, a) is ok """ + """Test if file does not exist, write in mode (w, w-, x, a) is ok""" def setUp(self): self.manager = get_foo_buildmanager() @@ -3395,11 +3740,11 @@ def tearDown(self): def test_write_empty_dataset(self): dataio = H5DataIO(shape=(5,), dtype=int) - foo = Foo('foo1', dataio, "I am foo1", 17, 3.14) - bucket = FooBucket('bucket1', [foo]) + foo = Foo("foo1", dataio, "I am foo1", 17, 3.14) + bucket = FooBucket("bucket1", [foo]) foofile = FooFile(buckets=[bucket]) - with HDF5IO(self.path, manager=self.manager, mode='w') as io: + with HDF5IO(self.path, manager=self.manager, mode="w") as io: io.write(foofile) self.assertIs(foo.my_data, dataio) @@ -3409,23 +3754,22 @@ def test_write_empty_dataset(self): def test_overwrite_dataset(self): dataio = H5DataIO(shape=(5,), dtype=int) - foo = Foo('foo1', dataio, "I am foo1", 17, 3.14) - bucket = FooBucket('bucket1', [foo]) + foo = Foo("foo1", dataio, "I am foo1", 17, 3.14) + bucket = FooBucket("bucket1", [foo]) foofile = FooFile(buckets=[bucket]) - with HDF5IO(self.path, manager=self.manager, mode='w') as io: + with HDF5IO(self.path, manager=self.manager, mode="w") as io: io.write(foofile) - with self.assertRaisesRegex(ValueError, 'Cannot overwrite H5DataIO.dataset'): - with HDF5IO(self.path2, manager=self.manager, mode='w') as io: + with self.assertRaisesRegex(ValueError, "Cannot overwrite H5DataIO.dataset"): + with HDF5IO(self.path2, manager=self.manager, mode="w") as io: io.write(foofile) class HDF5IOClassmethodTests(TestCase): - def setUp(self): self.path = get_temp_filepath() - self.f = h5py.File(self.path, 'w') + self.f = h5py.File(self.path, "w") def tearDown(self): self.f.close() @@ -3433,29 +3777,30 @@ def tearDown(self): os.remove(self.path) def test_setup_empty_dset(self): - dset = HDF5IO.__setup_empty_dset__(self.f, 'foo', {'shape': (3, 3), 'dtype': 'float'}) - self.assertEqual(dset.name, '/foo') + dset = HDF5IO.__setup_empty_dset__(self.f, "foo", {"shape": (3, 3), "dtype": "float"}) + self.assertEqual(dset.name, "/foo") self.assertTupleEqual(dset.shape, (3, 3)) self.assertIs(dset.dtype.type, np.float32) def test_setup_empty_dset_req_args(self): - with self.assertRaisesRegex(ValueError, 'Cannot setup empty dataset /foo without dtype'): - HDF5IO.__setup_empty_dset__(self.f, 'foo', {'shape': (3, 3)}) + with self.assertRaisesRegex(ValueError, "Cannot setup empty dataset /foo without dtype"): + HDF5IO.__setup_empty_dset__(self.f, "foo", {"shape": (3, 3)}) - with self.assertRaisesRegex(ValueError, 'Cannot setup empty dataset /foo without shape'): - HDF5IO.__setup_empty_dset__(self.f, 'foo', {'dtype': np.float32}) + with self.assertRaisesRegex(ValueError, "Cannot setup empty dataset /foo without shape"): + HDF5IO.__setup_empty_dset__(self.f, "foo", {"dtype": np.float32}) def test_setup_empty_dset_create_exception(self): - HDF5IO.__setup_empty_dset__(self.f, 'foo', {'shape': (3, 3), 'dtype': 'float'}) + HDF5IO.__setup_empty_dset__(self.f, "foo", {"shape": (3, 3), "dtype": "float"}) with self.assertRaisesRegex(Exception, "Could not create dataset foo in /"): - HDF5IO.__setup_empty_dset__(self.f, 'foo', {'shape': (3, 3), 'dtype': 'float'}) + HDF5IO.__setup_empty_dset__(self.f, "foo", {"shape": (3, 3), "dtype": "float"}) class H5DataIOTests(TestCase): - def _bad_arg_cm(self): - return self.assertRaisesRegex(ValueError, "Must specify 'dtype' and 'shape' " - "if not specifying 'data'") + return 
self.assertRaisesRegex( + ValueError, + "Must specify 'dtype' and 'shape' if not specifying 'data'", + ) def test_dataio_bad_args(self): with self._bad_arg_cm(): @@ -3473,5 +3818,8 @@ def test_dataio_len(self): def test_dataio_shape_then_data(self): dataio = H5DataIO(shape=(10, 10), dtype=int) - with self.assertRaisesRegex(ValueError, "Setting data when dtype and shape are not None is not supported"): + with self.assertRaisesRegex( + ValueError, + "Setting data when dtype and shape are not None is not supported", + ): dataio.data = list() diff --git a/tests/unit/test_io_hdf5_streaming.py b/tests/unit/test_io_hdf5_streaming.py index 9729778c7..78cbb7f74 100644 --- a/tests/unit/test_io_hdf5_streaming.py +++ b/tests/unit/test_io_hdf5_streaming.py @@ -1,11 +1,18 @@ -from copy import copy, deepcopy import os import urllib.request +from copy import copy, deepcopy + import h5py -from hdmf.build import TypeMap, BuildManager +from hdmf.build import BuildManager, TypeMap from hdmf.common import get_hdf5io, get_type_map -from hdmf.spec import GroupSpec, DatasetSpec, SpecNamespace, NamespaceBuilder, NamespaceCatalog +from hdmf.spec import ( + DatasetSpec, + GroupSpec, + NamespaceBuilder, + NamespaceCatalog, + SpecNamespace, +) from hdmf.testing import TestCase from hdmf.utils import docval, get_docval @@ -34,8 +41,10 @@ def setUp(self): nwb_container_spec = NWBGroupSpec( neurodata_type_def="NWBContainer", neurodata_type_inc="Container", - doc=("An abstract data type for a generic container storing collections of data and metadata. " - "Base type for all data and metadata containers."), + doc=( + "An abstract data type for a generic container storing collections of data and metadata. " + "Base type for all data and metadata containers." + ), ) subject_spec = NWBGroupSpec( neurodata_type_def="Subject", @@ -88,6 +97,7 @@ def test_basic_read(self): with get_hdf5io(s3_path, "r", manager=self.manager, driver="ros3") as io: io.read() + # Util functions and classes to enable loading of the NWB namespace -- see pynwb/src/pynwb/spec.py @@ -98,20 +108,32 @@ def __swap_inc_def(cls): # do not set default neurodata_type_inc for base hdmf-common types that should not have data_type_inc for arg in args: if arg["name"] == "data_type_def": - ret.append({"name": "neurodata_type_def", "type": str, - "doc": "the NWB data type this spec defines", "default": None}) + ret.append( + { + "name": "neurodata_type_def", + "type": str, + "doc": "the NWB data type this spec defines", + "default": None, + } + ) elif arg["name"] == "data_type_inc": - ret.append({"name": "neurodata_type_inc", "type": (clsname, str), - "doc": "the NWB data type this spec includes", "default": None}) + ret.append( + { + "name": "neurodata_type_inc", + "type": (clsname, str), + "doc": "the NWB data type this spec includes", + "default": None, + } + ) else: ret.append(copy(arg)) return ret class BaseStorageOverride: - """ This class is used for the purpose of overriding - BaseStorageSpec classmethods, without creating diamond - inheritance hierarchies. + """This class is used for the purpose of overriding + BaseStorageSpec classmethods, without creating diamond + inheritance hierarchies. 
""" __type_key = "neurodata_type" @@ -120,17 +142,17 @@ class BaseStorageOverride: @classmethod def type_key(cls): - """ Get the key used to store data type on an instance""" + """Get the key used to store data type on an instance""" return cls.__type_key @classmethod def inc_key(cls): - """ Get the key used to define a data_type include.""" + """Get the key used to define a data_type include.""" return cls.__inc_key @classmethod def def_key(cls): - """ Get the key used to define a data_type definition.""" + """Get the key used to define a data_type definition.""" return cls.__def_key @property @@ -166,7 +188,7 @@ def _translate_kwargs(cls, kwargs): class NWBDatasetSpec(BaseStorageOverride, DatasetSpec): - """ The Spec class to use for NWB dataset specifications. + """The Spec class to use for NWB dataset specifications. Classes will automatically include NWBData if None is specified. """ @@ -184,7 +206,7 @@ def __init__(self, **kwargs): class NWBGroupSpec(BaseStorageOverride, GroupSpec): - """ The Spec class to use for NWB group specifications. + """The Spec class to use for NWB group specifications. Classes will automatically include NWBContainer if None is specified. """ @@ -205,7 +227,7 @@ def dataset_spec_cls(cls): @docval({"name": "neurodata_type", "type": str, "doc": "the neurodata_type to retrieve"}) def get_neurodata_type(self, **kwargs): - """ Get a specification by "neurodata_type" """ + """Get a specification by "neurodata_type" """ return super().get_data_type(kwargs["neurodata_type"]) diff --git a/tests/unit/test_multicontainerinterface.py b/tests/unit/test_multicontainerinterface.py index 3ebe36773..454040b0e 100644 --- a/tests/unit/test_multicontainerinterface.py +++ b/tests/unit/test_multicontainerinterface.py @@ -6,57 +6,53 @@ class OData(Data): - pass class Foo(MultiContainerInterface): - __clsconf__ = [ { - 'attr': 'containers', - 'add': 'add_container', - 'type': (Container, ), - 'get': 'get_container', + "attr": "containers", + "add": "add_container", + "type": (Container,), + "get": "get_container", }, { - 'attr': 'data', - 'add': 'add_data', - 'type': (Data, OData), + "attr": "data", + "add": "add_data", + "type": (Data, OData), }, { - 'attr': 'foo_data', - 'add': 'add_foo_data', - 'type': OData, - 'create': 'create_foo_data', + "attr": "foo_data", + "add": "add_foo_data", + "type": OData, + "create": "create_foo_data", }, { - 'attr': 'things', - 'add': 'add_thing', - 'type': (Container, Data, OData), + "attr": "things", + "add": "add_thing", + "type": (Container, Data, OData), }, ] class FooSingle(MultiContainerInterface): - __clsconf__ = { - 'attr': 'containers', - 'add': 'add_container', - 'type': (Container, ), + "attr": "containers", + "add": "add_container", + "type": (Container,), } class Baz(MultiContainerInterface): - __containers = dict() __clsconf__ = [ { - 'attr': 'containers', - 'add': 'add_container', - 'type': Container, - 'get': 'get_container', + "attr": "containers", + "add": "add_container", + "type": Container, + "get": "get_container", }, ] @@ -64,7 +60,7 @@ class Baz(MultiContainerInterface): def __init__(self, name, other_arg, my_containers): super().__init__(name=name) self.other_arg = other_arg - self.containers = {'my ' + v.name: v for v in my_containers} + self.containers = {"my " + v.name: v for v in my_containers} @property def containers(self): @@ -76,61 +72,60 @@ def containers(self, value): class TestBasic(TestCase): - def test_init_docval(self): """Test that the docval for the __init__ method is set correctly.""" dv = 
get_docval(Foo.__init__) - self.assertEqual(dv[0]['name'], 'containers') - self.assertEqual(dv[1]['name'], 'data') - self.assertEqual(dv[2]['name'], 'foo_data') - self.assertEqual(dv[3]['name'], 'things') - self.assertTupleEqual(dv[0]['type'], (list, tuple, dict, Container)) - self.assertTupleEqual(dv[1]['type'], (list, tuple, dict, Data, OData)) - self.assertTupleEqual(dv[2]['type'], (list, tuple, dict, OData)) - self.assertTupleEqual(dv[3]['type'], (list, tuple, dict, Container, Data, OData)) - self.assertEqual(dv[0]['doc'], 'Container to store in this interface') - self.assertEqual(dv[1]['doc'], 'Data or OData to store in this interface') - self.assertEqual(dv[2]['doc'], 'OData to store in this interface') - self.assertEqual(dv[3]['doc'], 'Container, Data, or OData to store in this interface') + self.assertEqual(dv[0]["name"], "containers") + self.assertEqual(dv[1]["name"], "data") + self.assertEqual(dv[2]["name"], "foo_data") + self.assertEqual(dv[3]["name"], "things") + self.assertTupleEqual(dv[0]["type"], (list, tuple, dict, Container)) + self.assertTupleEqual(dv[1]["type"], (list, tuple, dict, Data, OData)) + self.assertTupleEqual(dv[2]["type"], (list, tuple, dict, OData)) + self.assertTupleEqual(dv[3]["type"], (list, tuple, dict, Container, Data, OData)) + self.assertEqual(dv[0]["doc"], "Container to store in this interface") + self.assertEqual(dv[1]["doc"], "Data or OData to store in this interface") + self.assertEqual(dv[2]["doc"], "OData to store in this interface") + self.assertEqual(dv[3]["doc"], "Container, Data, or OData to store in this interface") for i in range(4): - self.assertDictEqual(dv[i]['default'], {}) - self.assertEqual(dv[4]['name'], 'name') - self.assertEqual(dv[4]['type'], str) - self.assertEqual(dv[4]['doc'], 'the name of this container') - self.assertEqual(dv[4]['default'], 'Foo') + self.assertDictEqual(dv[i]["default"], {}) + self.assertEqual(dv[4]["name"], "name") + self.assertEqual(dv[4]["type"], str) + self.assertEqual(dv[4]["doc"], "the name of this container") + self.assertEqual(dv[4]["default"], "Foo") def test_add_docval(self): """Test that the docval for the add method is set correctly.""" expected_doc = "add_container(containers)\n\nAdd one or multiple Container objects to this Foo" self.assertTrue(Foo.add_container.__doc__.startswith(expected_doc)) dv = get_docval(Foo.add_container) - self.assertEqual(dv[0]['name'], 'containers') - self.assertTupleEqual(dv[0]['type'], (list, tuple, dict, Container)) - self.assertEqual(dv[0]['doc'], 'one or multiple Container objects to add to this Foo') - self.assertFalse('default' in dv[0]) + self.assertEqual(dv[0]["name"], "containers") + self.assertTupleEqual(dv[0]["type"], (list, tuple, dict, Container)) + self.assertEqual(dv[0]["doc"], "one or multiple Container objects to add to this Foo") + self.assertFalse("default" in dv[0]) def test_create_docval(self): """Test that the docval for the create method is set correctly.""" dv = get_docval(Foo.create_foo_data) - self.assertEqual(dv[0]['name'], 'name') - self.assertEqual(dv[1]['name'], 'data') + self.assertEqual(dv[0]["name"], "name") + self.assertEqual(dv[1]["name"], "data") def test_getter_docval(self): """Test that the docval for the get method is set correctly.""" dv = get_docval(Foo.get_container) - self.assertEqual(dv[0]['doc'], 'the name of the Container') - self.assertIsNone(dv[0]['default']) + self.assertEqual(dv[0]["doc"], "the name of the Container") + self.assertIsNone(dv[0]["default"]) def test_getitem_docval(self): """Test that the docval for 
__getitem__ is set correctly.""" dv = get_docval(Baz.__getitem__) - self.assertEqual(dv[0]['doc'], 'the name of the Container') - self.assertIsNone(dv[0]['default']) + self.assertEqual(dv[0]["doc"], "the name of the Container") + self.assertIsNone(dv[0]["default"]) def test_attr_property(self): """Test that a property is created for the attribute.""" properties = inspect.getmembers(Foo, lambda o: isinstance(o, property)) - match = [p for p in properties if p[0] == 'containers'] + match = [p for p in properties if p[0] == "containers"] self.assertEqual(len(match), 1) def test_attr_getter(self): @@ -142,33 +137,33 @@ def test_init_empty(self): """Test that initializing the MCI with no arguments initializes the attribute dict empty.""" foo = Foo() self.assertDictEqual(foo.containers, {}) - self.assertEqual(foo.name, 'Foo') + self.assertEqual(foo.name, "Foo") def test_init_multi(self): """Test that initializing the MCI with arguments populates the corresponding attribute dicts.""" - obj1 = Container('obj1') - data1 = Data('data1', [1, 2, 3]) + obj1 = Container("obj1") + data1 = Data("data1", [1, 2, 3]) foo = Foo(containers=obj1, data=data1) - self.assertDictEqual(foo.containers, {'obj1': obj1}) - self.assertDictEqual(foo.data, {'data1': data1}) + self.assertDictEqual(foo.containers, {"obj1": obj1}) + self.assertDictEqual(foo.data, {"data1": data1}) def test_init_custom_name(self): """Test that initializing the MCI with a custom name works.""" - foo = Foo(name='test_foo') - self.assertEqual(foo.name, 'test_foo') + foo = Foo(name="test_foo") + self.assertEqual(foo.name, "test_foo") # init, create, and setter calls add, so just test add def test_add_single(self): """Test that adding a container to the attribute dict correctly adds the container.""" - obj1 = Container('obj1') + obj1 = Container("obj1") foo = Foo() foo.add_container(obj1) - self.assertDictEqual(foo.containers, {'obj1': obj1}) + self.assertDictEqual(foo.containers, {"obj1": obj1}) self.assertIs(obj1.parent, foo) def test_add_single_modified(self): """Test that adding a container to the attribute dict correctly marks the MCI as modified.""" - obj1 = Container('obj1') + obj1 = Container("obj1") foo = Foo() foo.set_modified(False) # set to False so that we can test whether add_container makes it True foo.add_container(obj1) @@ -176,18 +171,18 @@ def test_add_single_modified(self): def test_add_single_not_parent(self): """Test that adding a container with a parent to the attribute dict correctly adds the container.""" - obj1 = Container('obj1') - obj2 = Container('obj2') + obj1 = Container("obj1") + obj2 = Container("obj2") obj1.parent = obj2 foo = Foo() foo.add_container(obj1) - self.assertDictEqual(foo.containers, {'obj1': obj1}) + self.assertDictEqual(foo.containers, {"obj1": obj1}) self.assertIs(obj1.parent, obj2) def test_add_single_not_parent_modified(self): """Test that adding a container with a parent to the attribute dict correctly marks the MCI as modified.""" - obj1 = Container('obj1') - obj2 = Container('obj2') + obj1 = Container("obj1") + obj2 = Container("obj2") obj1.parent = obj2 foo = Foo() foo.set_modified(False) # set to False so that we can test whether add_container makes it True @@ -196,7 +191,7 @@ def test_add_single_not_parent_modified(self): def test_add_single_dup(self): """Test that adding a duplicate container to the attribute dict raises an error.""" - obj1 = Container('obj1') + obj1 = Container("obj1") foo = Foo(obj1) msg = "'obj1' already exists in Foo 'Foo'" with self.assertRaisesWith(ValueError, msg): @@
-204,42 +199,42 @@ def test_add_single_dup(self): def test_add_list(self): """Test that adding a list to the attribute dict correctly adds the items.""" - obj1 = Container('obj1') - obj2 = Container('obj2') + obj1 = Container("obj1") + obj2 = Container("obj2") foo = Foo() foo.add_container([obj1, obj2]) - self.assertDictEqual(foo.containers, {'obj1': obj1, 'obj2': obj2}) + self.assertDictEqual(foo.containers, {"obj1": obj1, "obj2": obj2}) def test_add_dict(self): """Test that adding a dict to the attribute dict correctly adds the input dict values.""" - obj1 = Container('obj1') - obj2 = Container('obj2') + obj1 = Container("obj1") + obj2 = Container("obj2") foo = Foo() - foo.add_container({'a': obj1, 'b': obj2}) - self.assertDictEqual(foo.containers, {'obj1': obj1, 'obj2': obj2}) + foo.add_container({"a": obj1, "b": obj2}) + self.assertDictEqual(foo.containers, {"obj1": obj1, "obj2": obj2}) def test_attr_setter_none(self): """Test that setting the attribute dict to None does not alter the dict.""" - obj1 = Container('obj1') + obj1 = Container("obj1") foo = Foo(obj1) foo.containers = None - self.assertDictEqual(foo.containers, {'obj1': obj1}) + self.assertDictEqual(foo.containers, {"obj1": obj1}) def test_remove_child(self): """Test that removing a child container from the attribute dict resets the parent to None.""" - obj1 = Container('obj1') + obj1 = Container("obj1") foo = Foo(obj1) - del foo.containers['obj1'] + del foo.containers["obj1"] self.assertDictEqual(foo.containers, {}) self.assertIsNone(obj1.parent) def test_remove_non_child(self): """Test that removing a non-child container from the attribute dict resets the parent to None.""" - obj1 = Container('obj1') - obj2 = Container('obj2') + obj1 = Container("obj1") + obj2 = Container("obj2") obj1.parent = obj2 foo = Foo(obj1) - del foo.containers['obj1'] + del foo.containers["obj1"] self.assertDictEqual(foo.containers, {}) self.assertIs(obj1.parent, obj2) @@ -252,14 +247,14 @@ def test_getter_empty(self): def test_getter_none(self): """Test that calling the getter with no args and one item in the attribute returns the item.""" - obj1 = Container('obj1') + obj1 = Container("obj1") foo = Foo(obj1) self.assertIs(foo.get_container(), obj1) def test_getter_none_multiple(self): """Test that calling the getter with no args and multiple items in the attribute dict raises an error.""" - obj1 = Container('obj1') - obj2 = Container('obj2') + obj1 = Container("obj1") + obj2 = Container("obj2") foo = Foo([obj1, obj2]) msg = "More than one element in containers of Foo 'Foo' -- must specify a name." 
with self.assertRaisesWith(ValueError, msg): @@ -267,33 +262,33 @@ def test_getter_none_multiple(self): def test_getter_name(self): """Test that calling the getter with a correct key works.""" - obj1 = Container('obj1') + obj1 = Container("obj1") foo = Foo(obj1) - self.assertIs(foo.get_container('obj1'), obj1) + self.assertIs(foo.get_container("obj1"), obj1) def test_getter_name_not_found(self): """Test that calling the getter with a key not in the attribute dict raises a KeyError.""" foo = Foo() msg = "\"'obj1' not found in containers of Foo 'Foo'.\"" with self.assertRaisesWith(KeyError, msg): - foo.get_container('obj1') + foo.get_container("obj1") def test_getitem_multiconf(self): """Test that classes with multiple attribute configurations cannot use getitem.""" foo = Foo() msg = "'Foo' object is not subscriptable" with self.assertRaisesWith(TypeError, msg): - foo['aa'] + foo["aa"] def test_getitem(self): """Test that getitem works.""" - obj1 = Container('obj1') + obj1 = Container("obj1") foo = FooSingle(obj1) - self.assertIs(foo['obj1'], obj1) + self.assertIs(foo["obj1"], obj1) def test_getitem_single_none(self): """Test that getitem works when there is a single item and no name is given to getitem.""" - obj1 = Container('obj1') + obj1 = Container("obj1") foo = FooSingle(obj1) self.assertIs(foo[None], obj1) @@ -306,8 +301,8 @@ def test_getitem_empty(self): def test_getitem_multiple(self): """Test that an error is raised if the attribute dict has multiple values and no name is given to getitem.""" - obj1 = Container('obj1') - obj2 = Container('obj2') + obj1 = Container("obj1") + obj2 = Container("obj2") foo = FooSingle([obj1, obj2]) msg = "More than one Container in FooSingle 'FooSingle' -- must specify a name." with self.assertRaisesWith(ValueError, msg): @@ -315,45 +310,43 @@ def test_getitem_multiple(self): def test_getitem_not_found(self): """Test that a KeyError is raised if the key is not found using getitem.""" - obj1 = Container('obj1') + obj1 = Container("obj1") foo = FooSingle(obj1) msg = "\"'obj2' not found in FooSingle 'FooSingle'.\"" with self.assertRaisesWith(KeyError, msg): - foo['obj2'] + foo["obj2"] class TestOverrideInit(TestCase): - def test_override_init(self): """Test that overriding __init__ works.""" - obj1 = Container('obj1') - obj2 = Container('obj2') + obj1 = Container("obj1") + obj2 = Container("obj2") containers = [obj1, obj2] - baz = Baz(name='test_baz', other_arg=1, my_containers=containers) - self.assertEqual(baz.name, 'test_baz') + baz = Baz(name="test_baz", other_arg=1, my_containers=containers) + self.assertEqual(baz.name, "test_baz") self.assertEqual(baz.other_arg, 1) def test_override_property(self): """Test that overriding the attribute property works.""" - obj1 = Container('obj1') - obj2 = Container('obj2') + obj1 = Container("obj1") + obj2 = Container("obj2") containers = [obj1, obj2] - baz = Baz(name='test_baz', other_arg=1, my_containers=containers) - self.assertDictEqual(baz.containers, {'my obj1': obj1, 'my obj2': obj2}) + baz = Baz(name="test_baz", other_arg=1, my_containers=containers) + self.assertDictEqual(baz.containers, {"my obj1": obj1, "my obj2": obj2}) self.assertFalse(isinstance(baz.containers, LabelledDict)) - self.assertIs(baz.get_container('my obj1'), obj1) + self.assertIs(baz.get_container("my obj1"), obj1) baz.containers = {} self.assertDictEqual(baz.containers, {}) class TestNoClsConf(TestCase): - def test_mci_init(self): """Test that MultiContainerInterface cannot be instantiated.""" msg = "Can't instantiate class 
MultiContainerInterface." with self.assertRaisesWith(TypeError, msg): - MultiContainerInterface(name='a') + MultiContainerInterface(name="a") def test_init_no_cls_conf(self): """Test that defining an MCI subclass without __clsconf__ raises an error.""" @@ -361,10 +354,12 @@ def test_init_no_cls_conf(self): class Bar(MultiContainerInterface): pass - msg = ("MultiContainerInterface subclass Bar is missing __clsconf__ attribute. Please check that " - "the class is properly defined.") + msg = ( + "MultiContainerInterface subclass Bar is missing __clsconf__ attribute." + " Please check that the class is properly defined." + ) with self.assertRaisesWith(TypeError, msg): - Bar(name='a') + Bar(name="a") def test_init_superclass_no_cls_conf(self): """Test that a subclass of an MCI class without a __clsconf__ can be initialized.""" @@ -373,20 +368,18 @@ class Bar(MultiContainerInterface): pass class Qux(Bar): - __clsconf__ = { - 'attr': 'containers', - 'add': 'add_container', - 'type': Container, + "attr": "containers", + "add": "add_container", + "type": Container, } - obj1 = Container('obj1') + obj1 = Container("obj1") qux = Qux(obj1) - self.assertDictEqual(qux.containers, {'obj1': obj1}) + self.assertDictEqual(qux.containers, {"obj1": obj1}) class TestBadClsConf(TestCase): - def test_wrong_type(self): """Test that an error is raised if __clsconf__ is missing the add key.""" @@ -394,12 +387,11 @@ def test_wrong_type(self): with self.assertRaisesWith(TypeError, msg): class Bar(MultiContainerInterface): - __clsconf__ = ( { - 'attr': 'data', - 'add': 'add_data', - 'type': (Data, ), + "attr": "data", + "add": "add_data", + "type": (Data,), }, ) @@ -410,7 +402,6 @@ def test_missing_add(self): with self.assertRaisesWith(ValueError, msg): class Bar(MultiContainerInterface): - __clsconf__ = {} def test_missing_attr(self): @@ -420,9 +411,8 @@ def test_missing_attr(self): with self.assertRaisesWith(ValueError, msg): class Bar(MultiContainerInterface): - __clsconf__ = { - 'add': 'add_container', + "add": "add_container", } def test_missing_type(self): @@ -432,26 +422,26 @@ def test_missing_type(self): with self.assertRaisesWith(ValueError, msg): class Bar(MultiContainerInterface): - __clsconf__ = { - 'add': 'add_container', - 'attr': 'containers', + "add": "add_container", + "attr": "containers", } def test_create_multiple_types(self): """Test that an error is raised if __clsconf__ specifies 'create' key with multiple types.""" - msg = ("Cannot specify 'create' key in __clsconf__ for MultiContainerInterface subclass Bar " - "when 'type' key is not a single type") + msg = ( + "Cannot specify 'create' key in __clsconf__ for MultiContainerInterface" + " subclass Bar when 'type' key is not a single type" + ) with self.assertRaisesWith(ValueError, msg): class Bar(MultiContainerInterface): - __clsconf__ = { - 'attr': 'data', - 'add': 'add_data', - 'type': (Data, ), - 'create': 'create_data', + "attr": "data", + "add": "add_data", + "type": (Data,), + "create": "create_data", } def test_missing_add_multi(self): @@ -461,14 +451,13 @@ def test_missing_add_multi(self): with self.assertRaisesWith(ValueError, msg): class Bar(MultiContainerInterface): - __clsconf__ = [ { - 'attr': 'data', - 'add': 'add_data', - 'type': (Data, ), + "attr": "data", + "add": "add_data", + "type": (Data,), }, - {} + {}, ] def test_missing_attr_multi(self): @@ -478,16 +467,15 @@ def test_missing_attr_multi(self): with self.assertRaisesWith(ValueError, msg): class Bar(MultiContainerInterface): - __clsconf__ = [ { - 'attr': 'data', - 'add': 
'add_data', - 'type': (Data, ), + "attr": "data", + "add": "add_data", + "type": (Data,), }, { - 'add': 'add_container', - } + "add": "add_container", + }, ] def test_missing_type_multi(self): @@ -497,38 +485,38 @@ def test_missing_type_multi(self): with self.assertRaisesWith(ValueError, msg): class Bar(MultiContainerInterface): - __clsconf__ = [ { - 'attr': 'data', - 'add': 'add_data', - 'type': (Data, ), + "attr": "data", + "add": "add_data", + "type": (Data,), }, { - 'add': 'add_container', - 'attr': 'containers', - } + "add": "add_container", + "attr": "containers", + }, ] def test_create_multiple_types_multi(self): """Test that an error is raised if one item of a __clsconf__ list specifies 'create' key with multiple types.""" - msg = ("Cannot specify 'create' key in __clsconf__ for MultiContainerInterface subclass Bar " - "when 'type' key is not a single type at index 1") + msg = ( + "Cannot specify 'create' key in __clsconf__ for MultiContainerInterface" + " subclass Bar when 'type' key is not a single type at index 1" + ) with self.assertRaisesWith(ValueError, msg): class Bar(MultiContainerInterface): - __clsconf__ = [ { - 'attr': 'data', - 'add': 'add_data', - 'type': (Data, ), + "attr": "data", + "add": "add_data", + "type": (Data,), }, { - 'add': 'add_container', - 'attr': 'containers', - 'type': (Container, ), - 'create': 'create_container', - } + "add": "add_container", + "attr": "containers", + "type": (Container,), + "create": "create_container", + }, ] diff --git a/tests/unit/test_query.py b/tests/unit/test_query.py index b2ff267a7..67aae404b 100644 --- a/tests/unit/test_query.py +++ b/tests/unit/test_query.py @@ -3,16 +3,16 @@ import numpy as np from h5py import File -from hdmf.array import SortedArray, LinSpace + +from hdmf.array import LinSpace, SortedArray from hdmf.query import HDMFDataset, Query from hdmf.testing import TestCase class AbstractQueryMixin(metaclass=ABCMeta): - @abstractmethod def getDataset(self): - raise NotImplementedError('Cannot run test unless getDataset is implemented') + raise NotImplementedError("Cannot run test unless getDataset is implemented") def setUp(self): self.dset = self.getDataset() @@ -23,89 +23,138 @@ def test_get_dataset(self): self.assertIsInstance(array, SortedArray) def test___gt__(self): - ''' + """ Test wrapper greater than magic method - ''' + """ q = self.wrapper > 5 self.assertIsInstance(q, Query) result = q.evaluate() - expected = [False, False, False, False, False, - False, True, True, True, True] + expected = [ + False, + False, + False, + False, + False, + False, + True, + True, + True, + True, + ] expected = slice(6, 10) self.assertEqual(result, expected) def test___ge__(self): - ''' + """ Test wrapper greater than or equal magic method - ''' + """ q = self.wrapper >= 5 self.assertIsInstance(q, Query) result = q.evaluate() - expected = [False, False, False, False, False, - True, True, True, True, True] + expected = [ + False, + False, + False, + False, + False, + True, + True, + True, + True, + True, + ] expected = slice(5, 10) self.assertEqual(result, expected) def test___lt__(self): - ''' + """ Test wrapper less than magic method - ''' + """ q = self.wrapper < 5 self.assertIsInstance(q, Query) result = q.evaluate() - expected = [True, True, True, True, True, - False, False, False, False, False] + expected = [ + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + ] expected = slice(0, 5) self.assertEqual(result, expected) def test___le__(self): - ''' + """ Test wrapper less than or equal 
magic method - ''' + """ q = self.wrapper <= 5 self.assertIsInstance(q, Query) result = q.evaluate() - expected = [True, True, True, True, True, - True, False, False, False, False] + expected = [ + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + ] expected = slice(0, 6) self.assertEqual(result, expected) def test___eq__(self): - ''' + """ Test wrapper equals magic method - ''' + """ q = self.wrapper == 5 self.assertIsInstance(q, Query) result = q.evaluate() - expected = [False, False, False, False, False, - True, False, False, False, False] + expected = [ + False, + False, + False, + False, + False, + True, + False, + False, + False, + False, + ] expected = 5 self.assertTrue(np.array_equal(result, expected)) def test___ne__(self): - ''' + """ Test wrapper not equal magic method - ''' + """ q = self.wrapper != 5 self.assertIsInstance(q, Query) result = q.evaluate() - expected = [True, True, True, True, True, - False, True, True, True, True] + expected = [True, True, True, True, True, False, True, True, True, True] expected = [slice(0, 5), slice(6, 10)] self.assertTrue(np.array_equal(result, expected)) def test___getitem__(self): - ''' + """ Test wrapper getitem using slice - ''' + """ result = self.wrapper[0:5] expected = [0, 1, 2, 3, 4] self.assertTrue(np.array_equal(result, expected)) def test___getitem__query(self): - ''' + """ Test wrapper getitem using query - ''' + """ q = self.wrapper < 5 result = self.wrapper[q] expected = [0, 1, 2, 3, 4] @@ -113,13 +162,12 @@ def test___getitem__query(self): class SortedQueryTest(AbstractQueryMixin, TestCase): - - path = 'SortedQueryTest.h5' + path = "SortedQueryTest.h5" def getDataset(self): - self.f = File(self.path, 'w') + self.f = File(self.path, "w") self.input = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] - self.d = self.f.create_dataset('dset', data=self.input) + self.d = self.f.create_dataset("dset", data=self.input) return SortedArray(self.d) def tearDown(self): @@ -129,15 +177,13 @@ def tearDown(self): class LinspaceQueryTest(AbstractQueryMixin, TestCase): - - path = 'LinspaceQueryTest.h5' + path = "LinspaceQueryTest.h5" def getDataset(self): return LinSpace(0, 10, 1) class CompoundQueryTest(TestCase): - def getM(self): return SortedArray(np.arange(10, 20, 1)) diff --git a/tests/unit/test_table.py b/tests/unit/test_table.py index 9bb857627..35802ae09 100644 --- a/tests/unit/test_table.py +++ b/tests/unit/test_table.py @@ -1,61 +1,59 @@ import pandas as pd -from hdmf.container import Table, Row, RowGetter +from hdmf.container import Row, RowGetter, Table from hdmf.testing import TestCase class TestTable(TestCase): - @classmethod def get_table_class(cls): class MyTable(Table): - - __defaultname__ = 'my_table' + __defaultname__ = "my_table" __columns__ = [ - {'name': 'col1', 'type': str, 'help': 'a string column'}, - {'name': 'col2', 'type': int, 'help': 'an integer column'}, + {"name": "col1", "type": str, "help": "a string column"}, + {"name": "col2", "type": int, "help": "an integer column"}, ] + return MyTable def test_init(self): MyTable = TestTable.get_table_class() - table = MyTable('test_table') - self.assertTrue(hasattr(table, '__colidx__')) - self.assertEqual(table.__colidx__, {'col1': 0, 'col2': 1}) + table = MyTable("test_table") + self.assertTrue(hasattr(table, "__colidx__")) + self.assertEqual(table.__colidx__, {"col1": 0, "col2": 1}) def test_add_row_getitem(self): MyTable = TestTable.get_table_class() - table = MyTable('test_table') - table.add_row(col1='foo', col2=100) - table.add_row(col1='bar', 
col2=200) + table = MyTable("test_table") + table.add_row(col1="foo", col2=100) + table.add_row(col1="bar", col2=200) row1 = table[0] row2 = table[1] - self.assertEqual(row1, ('foo', 100)) - self.assertEqual(row2, ('bar', 200)) + self.assertEqual(row1, ("foo", 100)) + self.assertEqual(row2, ("bar", 200)) def test_to_dataframe(self): MyTable = TestTable.get_table_class() - table = MyTable('test_table') - table.add_row(col1='foo', col2=100) - table.add_row(col1='bar', col2=200) + table = MyTable("test_table") + table.add_row(col1="foo", col2=100) + table.add_row(col1="bar", col2=200) df = table.to_dataframe() - exp = pd.DataFrame(data=[{'col1': 'foo', 'col2': 100}, {'col1': 'bar', 'col2': 200}]) + exp = pd.DataFrame(data=[{"col1": "foo", "col2": 100}, {"col1": "bar", "col2": 200}]) pd.testing.assert_frame_equal(df, exp) def test_from_dataframe(self): MyTable = TestTable.get_table_class() - exp = pd.DataFrame(data=[{'col1': 'foo', 'col2': 100}, {'col1': 'bar', 'col2': 200}]) + exp = pd.DataFrame(data=[{"col1": "foo", "col2": 100}, {"col1": "bar", "col2": 200}]) table = MyTable.from_dataframe(exp) row1 = table[0] row2 = table[1] - self.assertEqual(row1, ('foo', 100)) - self.assertEqual(row2, ('bar', 200)) + self.assertEqual(row1, ("foo", 100)) + self.assertEqual(row2, ("bar", 200)) class TestRow(TestCase): - def setUp(self): self.MyTable = TestTable.get_table_class() @@ -64,29 +62,30 @@ class MyRow(Row): self.MyRow = MyRow - self.table = self.MyTable('test_table') + self.table = self.MyTable("test_table") def test_row_no_table(self): - with self.assertRaisesRegex(ValueError, '__table__ must be set if sub-classing Row'): + with self.assertRaisesRegex(ValueError, "__table__ must be set if sub-classing Row"): + class MyRow(Row): pass def test_table_init(self): MyTable = TestTable.get_table_class() - table = MyTable('test_table') - self.assertFalse(hasattr(table, 'row')) + table = MyTable("test_table") + self.assertFalse(hasattr(table, "row")) - table_w_row = self.MyTable('test_table') - self.assertTrue(hasattr(table_w_row, 'row')) + table_w_row = self.MyTable("test_table") + self.assertTrue(hasattr(table_w_row, "row")) self.assertIsInstance(table_w_row.row, RowGetter) self.assertIs(table_w_row.row.table, table_w_row) def test_init(self): - row1 = self.MyRow(col1='foo', col2=100, table=self.table) + row1 = self.MyRow(col1="foo", col2=100, table=self.table) # make sure Row object set up properly self.assertEqual(row1.idx, 0) - self.assertEqual(row1.col1, 'foo') + self.assertEqual(row1.col1, "foo") self.assertEqual(row1.col2, 100) # make sure Row object is stored in Table properly @@ -94,19 +93,19 @@ def test_init(self): self.assertEqual(tmp_row1, row1) def test_add_row_getitem(self): - self.table.add_row(col1='foo', col2=100) - self.table.add_row(col1='bar', col2=200) + self.table.add_row(col1="foo", col2=100) + self.table.add_row(col1="bar", col2=200) row1 = self.table.row[0] self.assertIsInstance(row1, self.MyRow) self.assertEqual(row1.idx, 0) - self.assertEqual(row1.col1, 'foo') + self.assertEqual(row1.col1, "foo") self.assertEqual(row1.col2, 100) row2 = self.table.row[1] self.assertIsInstance(row2, self.MyRow) self.assertEqual(row2.idx, 1) - self.assertEqual(row2.col1, 'bar') + self.assertEqual(row2.col1, "bar") self.assertEqual(row2.col2, 200) # test memoization @@ -114,11 +113,11 @@ def test_add_row_getitem(self): self.assertIs(row3, row1) def test_todict(self): - row1 = self.MyRow(col1='foo', col2=100, table=self.table) - self.assertEqual(row1.todict(), {'col1': 'foo', 'col2': 100}) + row1 = 
self.MyRow(col1="foo", col2=100, table=self.table) + self.assertEqual(row1.todict(), {"col1": "foo", "col2": 100}) def test___str__(self): - row1 = self.MyRow(col1='foo', col2=100, table=self.table) + row1 = self.MyRow(col1="foo", col2=100, table=self.table) row1_str = str(row1) expected_str = "Row(0, test_table) = {'col1': 'foo', 'col2': 100}" self.assertEqual(row1_str, expected_str) diff --git a/tests/unit/utils_test/test_core_DataChunk.py b/tests/unit/utils_test/test_core_DataChunk.py index 8ad4f7315..1c1f0cc7c 100644 --- a/tests/unit/utils_test/test_core_DataChunk.py +++ b/tests/unit/utils_test/test_core_DataChunk.py @@ -1,12 +1,12 @@ from copy import copy, deepcopy import numpy as np + from hdmf.data_utils import DataChunk from hdmf.testing import TestCase class DataChunkTests(TestCase): - def setUp(self): pass @@ -29,7 +29,7 @@ def test_datachunk_deepcopy(self): def test_datachunk_astype(self): obj = DataChunk(data=np.arange(3), selection=np.s_[0:3]) - newtype = np.dtype('int16') + newtype = np.dtype("int16") obj_astype = obj.astype(newtype) self.assertNotEqual(id(obj), id(obj_astype)) self.assertEqual(obj_astype.dtype, np.dtype(newtype)) diff --git a/tests/unit/utils_test/test_core_DataChunkIterator.py b/tests/unit/utils_test/test_core_DataChunkIterator.py index d24e34bd7..93c635ef7 100644 --- a/tests/unit/utils_test/test_core_DataChunkIterator.py +++ b/tests/unit/utils_test/test_core_DataChunkIterator.py @@ -1,11 +1,10 @@ import numpy as np -from hdmf.data_utils import DataChunkIterator, DataChunk +from hdmf.data_utils import DataChunk, DataChunkIterator from hdmf.testing import TestCase class DataChunkIteratorTests(TestCase): - def setUp(self): pass @@ -13,11 +12,10 @@ def tearDown(self): pass def test_none_iter(self): - """Test that DataChunkIterator __init__ sets defaults correctly and all chunks and recommended shapes are None. - """ - dci = DataChunkIterator(dtype=np.dtype('int')) + """Test that DataChunkIterator __init__ sets defaults and all chunks and recommended shapes are None.""" + dci = DataChunkIterator(dtype=np.dtype("int")) self.assertIsNone(dci.maxshape) - self.assertEqual(dci.dtype, np.dtype('int')) + self.assertEqual(dci.dtype, np.dtype("int")) self.assertEqual(dci.buffer_size, 1) self.assertEqual(dci.iter_axis, 0) count = 0 @@ -28,20 +26,20 @@ def test_none_iter(self): self.assertIsNone(dci.recommended_chunk_shape()) def test_list_none(self): - """Test that DataChunkIterator has no dtype or chunks when given a list of None. - """ + """Test that DataChunkIterator has no dtype or chunks when given a list of None.""" a = [None, None, None] - with self.assertRaisesWith(Exception, 'Data type could not be determined. Please specify dtype in ' - 'DataChunkIterator init.'): + with self.assertRaisesWith( + Exception, + "Data type could not be determined. Please specify dtype in DataChunkIterator init.", + ): DataChunkIterator(a) def test_list_none_dtype(self): - """Test that DataChunkIterator has the passed-in dtype and no chunks when given a list of None. 
- """ + """Test that DataChunkIterator has the passed-in dtype and no chunks when given a list of None.""" a = [None, None, None] - dci = DataChunkIterator(a, dtype=np.dtype('int')) + dci = DataChunkIterator(a, dtype=np.dtype("int")) self.assertTupleEqual(dci.maxshape, (3,)) - self.assertEqual(dci.dtype, np.dtype('int')) + self.assertEqual(dci.dtype, np.dtype("int")) count = 0 for chunk in dci: pass @@ -50,8 +48,7 @@ def test_list_none_dtype(self): self.assertIsNone(dci.recommended_chunk_shape()) def test_numpy_iter_unbuffered_first_axis(self): - """Test DataChunkIterator with numpy data, no buffering, and iterating on the first dimension. - """ + """Test DataChunkIterator with numpy data, no buffering, and iterating on the first dimension.""" a = np.arange(30).reshape(5, 2, 3) dci = DataChunkIterator(data=a, buffer_size=1) count = 0 @@ -63,8 +60,7 @@ def test_numpy_iter_unbuffered_first_axis(self): self.assertIsNone(dci.recommended_chunk_shape()) def test_numpy_iter_unbuffered_middle_axis(self): - """Test DataChunkIterator with numpy data, no buffering, and iterating on a middle dimension. - """ + """Test DataChunkIterator with numpy data, no buffering, and iterating on a middle dimension.""" a = np.arange(30).reshape(5, 2, 3) dci = DataChunkIterator(data=a, buffer_size=1, iter_axis=1) count = 0 @@ -76,8 +72,7 @@ def test_numpy_iter_unbuffered_middle_axis(self): self.assertIsNone(dci.recommended_chunk_shape()) def test_numpy_iter_unbuffered_last_axis(self): - """Test DataChunkIterator with numpy data, no buffering, and iterating on the last dimension. - """ + """Test DataChunkIterator with numpy data, no buffering, and iterating on the last dimension.""" a = np.arange(30).reshape(5, 2, 3) dci = DataChunkIterator(data=a, buffer_size=1, iter_axis=2) count = 0 @@ -89,8 +84,7 @@ def test_numpy_iter_unbuffered_last_axis(self): self.assertIsNone(dci.recommended_chunk_shape()) def test_numpy_iter_buffered_first_axis(self): - """Test DataChunkIterator with numpy data, buffering, and iterating on the first dimension. - """ + """Test DataChunkIterator with numpy data, buffering, and iterating on the first dimension.""" a = np.arange(30).reshape(5, 2, 3) dci = DataChunkIterator(data=a, buffer_size=2) count = 0 @@ -105,8 +99,7 @@ def test_numpy_iter_buffered_first_axis(self): self.assertIsNone(dci.recommended_chunk_shape()) def test_numpy_iter_buffered_middle_axis(self): - """Test DataChunkIterator with numpy data, buffering, and iterating on a middle dimension. - """ + """Test DataChunkIterator with numpy data, buffering, and iterating on a middle dimension.""" a = np.arange(45).reshape(5, 3, 3) dci = DataChunkIterator(data=a, buffer_size=2, iter_axis=1) count = 0 @@ -121,8 +114,7 @@ def test_numpy_iter_buffered_middle_axis(self): self.assertIsNone(dci.recommended_chunk_shape()) def test_numpy_iter_buffered_last_axis(self): - """Test DataChunkIterator with numpy data, buffering, and iterating on the last dimension. - """ + """Test DataChunkIterator with numpy data, buffering, and iterating on the last dimension.""" a = np.arange(30).reshape(5, 2, 3) dci = DataChunkIterator(data=a, buffer_size=2, iter_axis=2) count = 0 @@ -182,8 +174,7 @@ def test_standard_iterator_unmatched_buffersized(self): self.assertTupleEqual(dci.recommended_data_shape(), (10,)) # Test before and after iteration def test_multidimensional_list_first_axis(self): - """Test DataChunkIterator with multidimensional list data, no buffering, and iterating on the first dimension. 
- """ + """Test DataChunkIterator with multidimensional list data, no buffering, and iterating on the first dim.""" a = np.arange(30).reshape(5, 2, 3).tolist() dci = DataChunkIterator(a) self.assertTupleEqual(dci.maxshape, (5, 2, 3)) @@ -197,12 +188,13 @@ def test_multidimensional_list_first_axis(self): self.assertIsNone(dci.recommended_chunk_shape()) def test_multidimensional_list_middle_axis(self): - """Test DataChunkIterator with multidimensional list data, no buffering, and iterating on a middle dimension. - """ + """Test DataChunkIterator with multidimensional list data, no buffering, and iterating on a middle dimension.""" a = np.arange(30).reshape(5, 2, 3).tolist() - warn_msg = ('Iterating over an axis other than the first dimension of list or tuple data ' - 'involves converting the data object to a numpy ndarray, which may incur a computational ' - 'cost.') + warn_msg = ( + "Iterating over an axis other than the first dimension of list or tuple" + " data involves converting the data object to a numpy ndarray, which may" + " incur a computational cost." + ) with self.assertWarnsWith(UserWarning, warn_msg): dci = DataChunkIterator(a, iter_axis=1) self.assertTupleEqual(dci.maxshape, (5, 2, 3)) @@ -216,12 +208,13 @@ def test_multidimensional_list_middle_axis(self): self.assertIsNone(dci.recommended_chunk_shape()) def test_multidimensional_list_last_axis(self): - """Test DataChunkIterator with multidimensional list data, no buffering, and iterating on the last dimension. - """ + """Test DataChunkIterator with multidimensional list data, no buffering, and iterating on the last dimension.""" a = np.arange(30).reshape(5, 2, 3).tolist() - warn_msg = ('Iterating over an axis other than the first dimension of list or tuple data ' - 'involves converting the data object to a numpy ndarray, which may incur a computational ' - 'cost.') + warn_msg = ( + "Iterating over an axis other than the first dimension of list or tuple" + " data involves converting the data object to a numpy ndarray, which may" + " incur a computational cost." 
+ ) with self.assertWarnsWith(UserWarning, warn_msg): dci = DataChunkIterator(a, iter_axis=2) self.assertTupleEqual(dci.maxshape, (5, 2, 3)) @@ -241,7 +234,7 @@ def test_maxshape(self): self.assertEqual(daiter.maxshape, (None, 2, 3)) def test_dtype(self): - a = np.arange(30, dtype='int32').reshape(5, 2, 3) + a = np.arange(30, dtype="int32").reshape(5, 2, 3) aiter = iter(a) daiter = DataChunkIterator.from_iterable(aiter, buffer_size=2) self.assertEqual(daiter.dtype, a.dtype) @@ -273,13 +266,23 @@ def test_sparse_data_buffer_notaligned(self): self.assertListEqual(chunk.data.tolist(), [1, 2]) self.assertEqual(chunk.selection[0], slice(chunk.data[0] - 1, chunk.data[1])) elif count == 1: # [3, None] - self.assertListEqual(chunk.data.tolist(), [3, ]) + self.assertListEqual( + chunk.data.tolist(), + [ + 3, + ], + ) self.assertEqual(chunk.selection[0], slice(chunk.data[0] - 1, chunk.data[0])) elif count == 2: # [8, 9] self.assertListEqual(chunk.data.tolist(), [8, 9]) self.assertEqual(chunk.selection[0], slice(chunk.data[0] - 1, chunk.data[1])) else: # count == 3, [10] - self.assertListEqual(chunk.data.tolist(), [10, ]) + self.assertListEqual( + chunk.data.tolist(), + [ + 10, + ], + ) self.assertEqual(chunk.selection[0], slice(chunk.data[0] - 1, chunk.data[0])) count += 1 self.assertEqual(count, 4) @@ -354,6 +357,7 @@ def my_iter(): count = count + 1 yield val return + dci = DataChunkIterator(data=my_iter(), buffer_size=2) count = 0 for chunk in dci: @@ -375,6 +379,7 @@ def my_iter(): count = count + 1 yield val return + dci = DataChunkIterator(data=my_iter(), buffer_size=2, iter_axis=1) count = 0 for chunk in dci: @@ -396,6 +401,7 @@ def my_iter(): count = count + 1 yield val return + dci = DataChunkIterator(data=my_iter(), buffer_size=2, iter_axis=2) count = 0 for chunk in dci: @@ -417,6 +423,7 @@ def my_iter(): count = count + 1 yield val return + # iterator returns slices of size (5, 2) # because iter_axis is by default 0, these chunks will be placed along the first dimension dci = DataChunkIterator(data=my_iter(), buffer_size=2) @@ -433,7 +440,6 @@ def my_iter(): class DataChunkTests(TestCase): - def setUp(self): pass @@ -449,11 +455,11 @@ def test_len_operator_with_data(self): self.assertEqual(len(temp), 5) def test_dtype(self): - temp = DataChunk(np.arange(10).astype('int')) + temp = DataChunk(np.arange(10).astype("int")) temp_dtype = temp.dtype - self.assertEqual(temp_dtype, np.dtype('int')) + self.assertEqual(temp_dtype, np.dtype("int")) def test_astype(self): temp1 = DataChunk(np.arange(10).reshape(5, 2)) - temp2 = temp1.astype('float32') - self.assertEqual(temp2.dtype, np.dtype('float32')) + temp2 = temp1.astype("float32") + self.assertEqual(temp2.dtype, np.dtype("float32")) diff --git a/tests/unit/utils_test/test_core_DataIO.py b/tests/unit/utils_test/test_core_DataIO.py index 00941cb0e..eebb777b1 100644 --- a/tests/unit/utils_test/test_core_DataIO.py +++ b/tests/unit/utils_test/test_core_DataIO.py @@ -1,13 +1,13 @@ from copy import copy, deepcopy import numpy as np + from hdmf.container import Data from hdmf.data_utils import DataIO from hdmf.testing import TestCase class DataIOTests(TestCase): - def setUp(self): pass @@ -15,13 +15,13 @@ def tearDown(self): pass def test_copy(self): - obj = DataIO(data=[1., 2., 3.]) + obj = DataIO(data=[1.0, 2.0, 3.0]) obj_copy = copy(obj) self.assertNotEqual(id(obj), id(obj_copy)) self.assertEqual(id(obj.data), id(obj_copy.data)) def test_deepcopy(self): - obj = DataIO(data=[1., 2., 3.]) + obj = DataIO(data=[1.0, 2.0, 3.0]) obj_copy = deepcopy(obj) 
self.assertNotEqual(id(obj), id(obj_copy)) self.assertNotEqual(id(obj.data), id(obj_copy.data)) @@ -41,7 +41,7 @@ def test_set_dataio(self): """ dataio = DataIO() data = np.arange(30).reshape(5, 2, 3) - container = Data('wrapped_data', data) + container = Data("wrapped_data", data) container.set_dataio(dataio) self.assertIs(dataio.data, data) self.assertIs(dataio, container.data) @@ -52,7 +52,7 @@ def test_set_dataio_data_already_set(self): """ dataio = DataIO(data=np.arange(30).reshape(5, 2, 3)) data = np.arange(30).reshape(5, 2, 3) - container = Data('wrapped_data', data) + container = Data("wrapped_data", data) with self.assertRaisesWith(ValueError, "cannot overwrite 'data' on DataIO"): container.set_dataio(dataio) @@ -66,5 +66,8 @@ def test_dataio_options(self): DataIO(data=np.arange(5), shape=(3,)) dataio = DataIO(shape=(3,), dtype=int) - with self.assertRaisesRegex(ValueError, "Setting data when dtype and shape are not None is not supported"): + with self.assertRaisesRegex( + ValueError, + "Setting data when dtype and shape are not None is not supported", + ): dataio.data = np.arange(5) diff --git a/tests/unit/utils_test/test_core_GenericDataChunkIterator.py b/tests/unit/utils_test/test_core_GenericDataChunkIterator.py index 076260b55..282824edb 100644 --- a/tests/unit/utils_test/test_core_GenericDataChunkIterator.py +++ b/tests/unit/utils_test/test_core_GenericDataChunkIterator.py @@ -1,17 +1,18 @@ import unittest -import numpy as np from pathlib import Path -from tempfile import mkdtemp from shutil import rmtree -from typing import Tuple, Iterable +from tempfile import mkdtemp +from typing import Iterable, Tuple import h5py +import numpy as np from hdmf.data_utils import GenericDataChunkIterator from hdmf.testing import TestCase try: import tqdm # noqa: F401 + TQDM_INSTALLED = True except ImportError: TQDM_INSTALLED = False @@ -40,7 +41,9 @@ def __init__(self, array: np.ndarray, **kwargs): def _get_data(self, selection) -> np.ndarray: return self.array[selection] - def _get_maxshape(self) -> Tuple[np.uint64, ...]: # Undesirable return type, but can be handled + def _get_maxshape( + self, + ) -> Tuple[np.uint64, ...]: # Undesirable return type, but can be handled return tuple(np.uint64(x) for x in self.array.shape) def _get_dtype(self) -> np.dtype: @@ -62,12 +65,13 @@ def check_first_data_chunk_call(self, expected_selection, iterator_options): np.testing.assert_array_equal(first_data_chunk, self.test_array[expected_selection]) def check_direct_hdf5_write(self, iterator_options): - iterator = self.TestNumpyArrayDataChunkIterator( - array=self.test_array, **iterator_options - ) + iterator = self.TestNumpyArrayDataChunkIterator(array=self.test_array, **iterator_options) with h5py.File(name=self.test_dir / "test_generic_iterator_array.hdf5", mode="w") as f: dset = f.create_dataset( - name="test", shape=self.test_array.shape, dtype="int16", chunks=iterator.chunk_shape + name="test", + shape=self.test_array.shape, + dtype="int16", + chunks=iterator.chunk_shape, ) for chunk in iterator: dset[chunk.selection] = chunk.data @@ -88,8 +92,8 @@ class TestGenericDataChunkIterator(GenericDataChunkIterator): with self.assertRaisesWith( exc_type=TypeError, exc_msg=( - "Can't instantiate abstract class TestGenericDataChunkIterator with abstract methods " - "_get_data, _get_dtype, _get_maxshape" + "Can't instantiate abstract class TestGenericDataChunkIterator with" + " abstract methods _get_data, _get_dtype, _get_maxshape" ), ): TestGenericDataChunkIterator() @@ -110,26 +114,20 @@ def 
test_joint_option_assertions(self): chunk_shape = (2001, 384) with self.assertRaisesWith( exc_type=AssertionError, - exc_msg=( - f"Some dimensions of chunk_shape ({chunk_shape}) exceed the " - f"data dimensions ((2000, 384))!" - ), + exc_msg=f"Some dimensions of chunk_shape ({chunk_shape}) exceed the data dimensions ((2000, 384))!", ): - self.TestNumpyArrayDataChunkIterator( - array=self.test_array, chunk_shape=chunk_shape - ) + self.TestNumpyArrayDataChunkIterator(array=self.test_array, chunk_shape=chunk_shape) buffer_shape = (1000, 192) chunk_shape = (100, 384) with self.assertRaisesWith( exc_type=AssertionError, - exc_msg=( - f"Some dimensions of chunk_shape ({chunk_shape}) exceed the " - f"buffer shape ({buffer_shape})!" - ), + exc_msg=f"Some dimensions of chunk_shape ({chunk_shape}) exceed the buffer shape ({buffer_shape})!", ): self.TestNumpyArrayDataChunkIterator( - array=self.test_array, buffer_shape=buffer_shape, chunk_shape=chunk_shape + array=self.test_array, + buffer_shape=buffer_shape, + chunk_shape=chunk_shape, ) buffer_shape = (1000, 192) @@ -137,19 +135,21 @@ def test_joint_option_assertions(self): with self.assertRaisesWith( exc_type=AssertionError, exc_msg=( - f"Some dimensions of chunk_shape ({chunk_shape}) do not evenly divide the " - f"buffer shape ({buffer_shape})!" + f"Some dimensions of chunk_shape ({chunk_shape}) do not evenly divide" + f" the buffer shape ({buffer_shape})!" ), ): self.TestNumpyArrayDataChunkIterator( - array=self.test_array, buffer_shape=buffer_shape, chunk_shape=chunk_shape + array=self.test_array, + buffer_shape=buffer_shape, + chunk_shape=chunk_shape, ) def test_buffer_option_assertion_negative_buffer_gb(self): buffer_gb = -1 with self.assertRaisesWith( exc_type=AssertionError, - exc_msg=f"buffer_gb ({buffer_gb}) must be greater than zero!" + exc_msg=f"buffer_gb ({buffer_gb}) must be greater than zero!", ): self.TestNumpyArrayDataChunkIterator(array=self.test_array, buffer_gb=buffer_gb) @@ -160,7 +160,7 @@ def test_buffer_option_assertion_exceed_maxshape(self): exc_msg=( f"Some dimensions of buffer_shape ({buffer_shape}) exceed the data " f"dimensions ({self.test_array.shape})!" - ) + ), ): self.TestNumpyArrayDataChunkIterator(array=self.test_array, buffer_shape=buffer_shape) @@ -168,7 +168,7 @@ def test_buffer_option_assertion_negative_shape(self): buffer_shape = (-1, 384) with self.assertRaisesWith( exc_type=AssertionError, - exc_msg=f"Some dimensions of buffer_shape ({buffer_shape}) are less than zero!" + exc_msg=f"Some dimensions of buffer_shape ({buffer_shape}) are less than zero!", ): self.TestNumpyArrayDataChunkIterator(array=self.test_array, buffer_shape=buffer_shape) @@ -176,7 +176,7 @@ def test_chunk_option_assertion_negative_chunk_mb(self): chunk_mb = -1 with self.assertRaisesWith( exc_type=AssertionError, - exc_msg=f"chunk_mb ({chunk_mb}) must be greater than zero!" + exc_msg=f"chunk_mb ({chunk_mb}) must be greater than zero!", ): self.TestNumpyArrayDataChunkIterator(array=self.test_array, chunk_mb=chunk_mb) @@ -184,7 +184,7 @@ def test_chunk_option_assertion_negative_shape(self): chunk_shape = (-1, 384) with self.assertRaisesWith( exc_type=AssertionError, - exc_msg=f"Some dimensions of chunk_shape ({chunk_shape}) are less than zero!" 
+ exc_msg=f"Some dimensions of chunk_shape ({chunk_shape}) are less than zero!", ): self.TestNumpyArrayDataChunkIterator(array=self.test_array, chunk_shape=chunk_shape) @@ -192,7 +192,7 @@ def test_chunk_option_assertion_negative_shape(self): def test_progress_bar_assertion(self): with self.assertWarnsWith( warn_type=UserWarning, - exc_msg="Option 'total' in 'progress_bar_options' is not allowed to be over-written! Ignoring." + exc_msg="Option 'total' in 'progress_bar_options' is not allowed to be over-written! Ignoring.", ): _ = self.TestNumpyArrayDataChunkIterator( array=self.test_array, @@ -239,7 +239,7 @@ def test_manual_chunk_shape_attribute_int_type(self): self.check_all_of_iterable_is_python_int( iterable=self.TestNumpyArrayDataChunkIterator( array=self.test_array, - chunk_shape=(np.uint64(100), np.uint64(2)) + chunk_shape=(np.uint64(100), np.uint64(2)), ).chunk_shape ) @@ -260,14 +260,17 @@ def test_num_buffers(self): expected_num_buffers = 9 test = self.TestNumpyArrayDataChunkIterator( - array=self.test_array, buffer_shape=buffer_shape, chunk_shape=chunk_shape + array=self.test_array, + buffer_shape=buffer_shape, + chunk_shape=chunk_shape, ) self.assertEqual(first=test.num_buffers, second=expected_num_buffers) def test_numpy_array_chunk_iterator(self): iterator_options = dict() self.check_first_data_chunk_call( - expected_selection=(slice(0, 2000), slice(0, 384)), iterator_options=iterator_options + expected_selection=(slice(0, 2000), slice(0, 384)), + iterator_options=iterator_options, ) self.check_direct_hdf5_write(iterator_options=iterator_options) @@ -285,12 +288,7 @@ def test_buffer_gb_option(self): resulting_buffer_shape = (1580, 316) iterator_options = dict(buffer_gb=0.0005) self.check_first_data_chunk_call( - expected_selection=tuple( - [ - slice(0, buffer_shape_axis) - for buffer_shape_axis in resulting_buffer_shape - ] - ), + expected_selection=tuple([slice(0, buffer_shape_axis) for buffer_shape_axis in resulting_buffer_shape]), iterator_options=iterator_options, ) self.check_direct_hdf5_write(iterator_options=iterator_options) @@ -300,12 +298,7 @@ def test_buffer_gb_option(self): for buffer_gb_input_dtype_pass in [2, 2.0]: iterator_options = dict(buffer_gb=2) self.check_first_data_chunk_call( - expected_selection=tuple( - [ - slice(0, buffer_shape_axis) - for buffer_shape_axis in resulting_buffer_shape - ] - ), + expected_selection=tuple([slice(0, buffer_shape_axis) for buffer_shape_axis in resulting_buffer_shape]), iterator_options=iterator_options, ) self.check_direct_hdf5_write(iterator_options=iterator_options) @@ -329,14 +322,24 @@ def test_chunk_mb_option_larger_than_total_size(self): def test_chunk_mb_option_while_condition(self): """Test to evoke while condition of default shaping method.""" expected_chunk_shape = (2, 79, 79) - special_array = np.random.randint(low=-(2 ** 15), high=2 ** 15 - 1, size=(2, 2000, 2000), dtype="int16") + special_array = np.random.randint( + low=-(2**15), + high=2**15 - 1, + size=(2, 2000, 2000), + dtype="int16", + ) iterator = self.TestNumpyArrayDataChunkIterator(array=special_array) self.assertEqual(iterator.chunk_shape, expected_chunk_shape) def test_chunk_mb_option_while_condition_unit_maxshape_axis(self): """Test to evoke while condition of default shaping method.""" expected_chunk_shape = (1, 79, 79) - special_array = np.random.randint(low=-(2 ** 15), high=2 ** 15 - 1, size=(1, 2000, 2000), dtype="int16") + special_array = np.random.randint( + low=-(2**15), + high=2**15 - 1, + size=(1, 2000, 2000), + dtype="int16", + ) iterator 
= self.TestNumpyArrayDataChunkIterator(array=special_array) self.assertEqual(iterator.chunk_shape, expected_chunk_shape) @@ -346,7 +349,9 @@ def test_progress_bar(self): desc = "Testing progress bar..." with open(file=out_text_file, mode="w") as file: iterator = self.TestNumpyArrayDataChunkIterator( - array=self.test_array, display_progress=True, progress_bar_options=dict(file=file, desc=desc) + array=self.test_array, + display_progress=True, + progress_bar_options=dict(file=file, desc=desc), ) j = 0 for buffer in iterator: @@ -365,8 +370,9 @@ def test_progress_bar_no_options(self): def test_tqdm_not_installed(self): with self.assertWarnsWith( warn_type=UserWarning, - exc_msg=("You must install tqdm to use the progress bar feature (pip install tqdm)! " - "Progress bar is disabled.") + exc_msg=( + "You must install tqdm to use the progress bar feature (pip install tqdm)! Progress bar is disabled." + ), ): dci = self.TestNumpyArrayDataChunkIterator( array=self.test_array, diff --git a/tests/unit/utils_test/test_core_ShapeValidator.py b/tests/unit/utils_test/test_core_ShapeValidator.py index bde86a3b3..b60220d75 100644 --- a/tests/unit/utils_test/test_core_ShapeValidator.py +++ b/tests/unit/utils_test/test_core_ShapeValidator.py @@ -1,11 +1,11 @@ import numpy as np + from hdmf.common.table import DynamicTable, DynamicTableRegion, VectorData -from hdmf.data_utils import ShapeValidatorResult, DataChunkIterator, assertEqualShape +from hdmf.data_utils import DataChunkIterator, ShapeValidatorResult, assertEqualShape from hdmf.testing import TestCase class ShapeValidatorTests(TestCase): - def setUp(self): pass @@ -32,7 +32,7 @@ def test_array_dimensions_mismatch(self): d2 = np.arange(10).reshape(5, 2) res = assertEqualShape(d1, d2) self.assertFalse(res.result) - self.assertEqual(res.error, 'AXIS_LEN_ERROR') + self.assertEqual(res.error, "AXIS_LEN_ERROR") self.assertTupleEqual(res.ignored, ()) self.assertTupleEqual(res.unmatched, ((0, 0), (1, 1))) self.assertTupleEqual(res.shape1, (2, 5)) @@ -46,7 +46,7 @@ def test_array_unequal_number_of_dimensions(self): d2 = np.arange(20).reshape(5, 2, 2) res = assertEqualShape(d1, d2) self.assertFalse(res.result) - self.assertEqual(res.error, 'NUM_AXES_ERROR') + self.assertEqual(res.error, "NUM_AXES_ERROR") self.assertTupleEqual(res.ignored, ()) self.assertTupleEqual(res.unmatched, ()) self.assertTupleEqual(res.shape1, (2, 5)) @@ -102,7 +102,7 @@ def test_array_axis_index_out_of_bounds_single_axis(self): d2 = np.arange(20).reshape(5, 2, 2) res = assertEqualShape(d1, d2, 4, 1) self.assertFalse(res.result) - self.assertEqual(res.error, 'AXIS_OUT_OF_BOUNDS') + self.assertEqual(res.error, "AXIS_OUT_OF_BOUNDS") self.assertTupleEqual(res.ignored, ()) self.assertTupleEqual(res.unmatched, ()) self.assertTupleEqual(res.shape1, (2, 5)) @@ -116,7 +116,7 @@ def test_array_axis_index_out_of_bounds_mutilple_axis(self): d2 = np.arange(20).reshape(5, 2, 2) res = assertEqualShape(d1, d2, [0, 1], [5, 0]) self.assertFalse(res.result) - self.assertEqual(res.error, 'AXIS_OUT_OF_BOUNDS') + self.assertEqual(res.error, "AXIS_OUT_OF_BOUNDS") self.assertTupleEqual(res.ignored, ()) self.assertTupleEqual(res.unmatched, ()) self.assertTupleEqual(res.shape1, (2, 5)) @@ -158,7 +158,7 @@ def test_DataChunkIterator_error_on_undetermined_axis(self): d2 = DataChunkIterator(data=np.arange(10).reshape(2, 5)) res = assertEqualShape(d1, d2, ignore_undetermined=False) self.assertFalse(res.result) - self.assertEqual(res.error, 'AXIS_LEN_ERROR') + self.assertEqual(res.error, "AXIS_LEN_ERROR") 
self.assertTupleEqual(res.ignored, ()) self.assertTupleEqual(res.unmatched, ((0, 0),)) self.assertTupleEqual(res.shape1, (None, 5)) @@ -169,29 +169,29 @@ def test_DataChunkIterator_error_on_undetermined_axis(self): def test_DynamicTableRegion_shape_validation(self): # Create a test DynamicTable dt_spec = [ - {'name': 'foo', 'description': 'foo column'}, - {'name': 'bar', 'description': 'bar column'}, - {'name': 'baz', 'description': 'baz column'}, + {"name": "foo", "description": "foo column"}, + {"name": "bar", "description": "bar column"}, + {"name": "baz", "description": "baz column"}, ] dt_data = [ [1, 2, 3, 4, 5], [10.0, 20.0, 30.0, 40.0, 50.0], - ['cat', 'dog', 'bird', 'fish', 'lizard'] + ["cat", "dog", "bird", "fish", "lizard"], ] - columns = [ - VectorData(name=s['name'], description=s['description'], data=d) - for s, d in zip(dt_spec, dt_data) - ] - dt = DynamicTable(name="with_columns_and_data", description="a test table", columns=columns) + columns = [VectorData(name=s["name"], description=s["description"], data=d) for s, d in zip(dt_spec, dt_data)] + dt = DynamicTable( + name="with_columns_and_data", + description="a test table", + columns=columns, + ) # Create test DynamicTableRegion - dtr = DynamicTableRegion(name='dtr', data=[1, 2, 2], description='desc', table=dt) + dtr = DynamicTableRegion(name="dtr", data=[1, 2, 2], description="desc", table=dt) # Confirm that the shapes match res = assertEqualShape(dtr, np.arange(9).reshape(3, 3)) self.assertTrue(res.result) class ShapeValidatorResultTests(TestCase): - def setUp(self): pass @@ -200,18 +200,25 @@ def tearDown(self): def test_default_message(self): temp = ShapeValidatorResult() - temp.error = 'AXIS_LEN_ERROR' + temp.error = "AXIS_LEN_ERROR" self.assertEqual(temp.default_message, ShapeValidatorResult.SHAPE_ERROR[temp.error]) def test_set_error_to_illegal_type(self): temp = ShapeValidatorResult() with self.assertRaises(ValueError): - temp.error = 'MY_ILLEGAL_ERROR_TYPE' + temp.error = "MY_ILLEGAL_ERROR_TYPE" def test_ensure_use_of_tuples_during_asignment(self): temp = ShapeValidatorResult() temp_d = [1, 2] - temp_cases = ['shape1', 'shape2', 'axes1', 'axes2', 'ignored', 'unmatched'] + temp_cases = [ + "shape1", + "shape2", + "axes1", + "axes2", + "ignored", + "unmatched", + ] for var in temp_cases: setattr(temp, var, temp_d) - self.assertIsInstance(getattr(temp, var), tuple, var) + self.assertIsInstance(getattr(temp, var), tuple, var) diff --git a/tests/unit/utils_test/test_docval.py b/tests/unit/utils_test/test_docval.py index d0ea934f7..3ad80907a 100644 --- a/tests/unit/utils_test/test_docval.py +++ b/tests/unit/utils_test/test_docval.py @@ -1,67 +1,106 @@ import numpy as np + from hdmf.testing import TestCase -from hdmf.utils import (docval, fmt_docval_args, get_docval, getargs, popargs, AllowPositional, get_docval_macro, - docval_macro, popargs_to_dict, call_docval_func) +from hdmf.utils import ( + AllowPositional, + call_docval_func, + docval, + docval_macro, + fmt_docval_args, + get_docval, + get_docval_macro, + getargs, + popargs, + popargs_to_dict, +) class MyTestClass(object): - - @docval({'name': 'arg1', 'type': str, 'doc': 'argument1 is a str'}) + @docval({"name": "arg1", "type": str, "doc": "argument1 is a str"}) def basic_add(self, **kwargs): return kwargs - @docval({'name': 'arg1', 'type': str, 'doc': 'argument1 is a str'}, - {'name': 'arg2', 'type': int, 'doc': 'argument2 is a int'}) + @docval( + {"name": "arg1", "type": str, "doc": "argument1 is a str"}, + {"name": "arg2", "type": int, "doc": "argument2 is a 
int"}, + ) def basic_add2(self, **kwargs): return kwargs - @docval({'name': 'arg1', 'type': str, 'doc': 'argument1 is a str'}, - {'name': 'arg2', 'type': 'int', 'doc': 'argument2 is a int'}, - {'name': 'arg3', 'type': bool, 'doc': 'argument3 is a bool. it defaults to False', 'default': False}) + @docval( + {"name": "arg1", "type": str, "doc": "argument1 is a str"}, + {"name": "arg2", "type": "int", "doc": "argument2 is a int"}, + {"name": "arg3", "type": bool, "doc": "argument3 is a bool. it defaults to False", "default": False}, + ) def basic_add2_kw(self, **kwargs): return kwargs - @docval({'name': 'arg1', 'type': str, 'doc': 'argument1 is a str', 'default': 'a'}, - {'name': 'arg2', 'type': int, 'doc': 'argument2 is a int', 'default': 1}) + @docval( + {"name": "arg1", "type": str, "doc": "argument1 is a str", "default": "a"}, + {"name": "arg2", "type": int, "doc": "argument2 is a int", "default": 1}, + ) def basic_only_kw(self, **kwargs): return kwargs - @docval({'name': 'arg1', 'type': str, 'doc': 'argument1 is a str'}, - {'name': 'arg2', 'type': 'int', 'doc': 'argument2 is a int'}, - {'name': 'arg3', 'type': bool, 'doc': 'argument3 is a bool. it defaults to False', 'default': False}, - allow_extra=True) + @docval( + {"name": "arg1", "type": str, "doc": "argument1 is a str"}, + {"name": "arg2", "type": "int", "doc": "argument2 is a int"}, + {"name": "arg3", "type": bool, "doc": "argument3 is a bool. it defaults to False", "default": False}, + allow_extra=True, + ) def basic_add2_kw_allow_extra(self, **kwargs): return kwargs class MyTestSubclass(MyTestClass): - - @docval({'name': 'arg1', 'type': str, 'doc': 'argument1 is a str'}, - {'name': 'arg2', 'type': int, 'doc': 'argument2 is a int'}) + @docval( + {"name": "arg1", "type": str, "doc": "argument1 is a str"}, + {"name": "arg2", "type": int, "doc": "argument2 is a int"}, + ) def basic_add(self, **kwargs): return kwargs - @docval({'name': 'arg1', 'type': str, 'doc': 'argument1 is a str'}, - {'name': 'arg2', 'type': int, 'doc': 'argument2 is a int'}, - {'name': 'arg3', 'type': bool, 'doc': 'argument3 is a bool. it defaults to False', 'default': False}, - {'name': 'arg4', 'type': str, 'doc': 'argument4 is a str'}, - {'name': 'arg5', 'type': 'float', 'doc': 'argument5 is a float'}, - {'name': 'arg6', 'type': bool, 'doc': 'argument6 is a bool. it defaults to None', 'default': None}) + @docval( + {"name": "arg1", "type": str, "doc": "argument1 is a str"}, + {"name": "arg2", "type": int, "doc": "argument2 is a int"}, + {"name": "arg3", "type": bool, "doc": "argument3 is a bool. it defaults to False", "default": False}, + {"name": "arg4", "type": str, "doc": "argument4 is a str"}, + {"name": "arg5", "type": "float", "doc": "argument5 is a float"}, + {"name": "arg6", "type": bool, "doc": "argument6 is a bool. it defaults to None", "default": None}, + ) def basic_add2_kw(self, **kwargs): return kwargs class MyChainClass(MyTestClass): - - @docval({'name': 'arg1', 'type': (str, 'MyChainClass'), 'doc': 'arg1 is a string or MyChainClass'}, - {'name': 'arg2', 'type': ('array_data', 'MyChainClass'), - 'doc': 'arg2 is array data or MyChainClass. it defaults to None', 'default': None}, - {'name': 'arg3', 'type': ('array_data', 'MyChainClass'), 'doc': 'arg3 is array data or MyChainClass', - 'shape': (None, 2)}, - {'name': 'arg4', 'type': ('array_data', 'MyChainClass'), - 'doc': 'arg3 is array data or MyChainClass. 
it defaults to None.', 'shape': (None, 2), 'default': None}) + @docval( + { + "name": "arg1", + "type": (str, "MyChainClass"), + "doc": "arg1 is a string or MyChainClass", + }, + { + "name": "arg2", + "type": ("array_data", "MyChainClass"), + "doc": "arg2 is array data or MyChainClass. it defaults to None", + "default": None, + }, + { + "name": "arg3", + "type": ("array_data", "MyChainClass"), + "doc": "arg3 is array data or MyChainClass", + "shape": (None, 2), + }, + { + "name": "arg4", + "type": ("array_data", "MyChainClass"), + "doc": "arg3 is array data or MyChainClass. it defaults to None.", + "shape": (None, 2), + "default": None, + }, + ) def __init__(self, **kwargs): - self._arg1, self._arg2, self._arg3, self._arg4 = popargs('arg1', 'arg2', 'arg3', 'arg4', kwargs) + self._arg1, self._arg2, self._arg3, self._arg4 = popargs("arg1", "arg2", "arg3", "arg4", kwargs) @property def arg1(self): @@ -101,24 +140,28 @@ def arg4(self, val): class TestDocValidator(TestCase): - def setUp(self): self.test_obj = MyTestClass() self.test_obj_sub = MyTestSubclass() def test_bad_type(self): - exp_msg = (r"docval for arg1: error parsing argument type: argtype must be a type, " - r"a str, a list, a tuple, or None - got ") + exp_msg = ( + r"docval for arg1: error parsing argument type: argtype must be a type, " + r"a str, a list, a tuple, or None - got " + ) with self.assertRaisesRegex(Exception, exp_msg): - @docval({'name': 'arg1', 'type': {'a': 1}, 'doc': 'this is a bad type'}) + + @docval({"name": "arg1", "type": {"a": 1}, "doc": "this is a bad type"}) def method(self, **kwargs): pass + method(self, arg1=1234560) def test_bad_shape(self): - @docval({'name': 'arg1', 'type': 'array_data', 'doc': 'this is a bad shape', 'shape': (None, 2)}) + @docval({"name": "arg1", "type": "array_data", "doc": "this is a bad shape", "shape": (None, 2)}) def method(self, **kwargs): pass + with self.assertRaises(ValueError): method(self, arg1=[[1]]) with self.assertRaises(ValueError): @@ -127,8 +170,7 @@ def method(self, **kwargs): method(self, arg1=[[1, 1]]) def test_multi_shape(self): - @docval({'name': 'arg1', 'type': 'array_data', 'doc': 'this is a bad shape', - 'shape': ((None,), (None, 2))}) + @docval({"name": "arg1", "type": "array_data", "doc": "this is a bad shape", "shape": ((None,), (None, 2))}) def method1(self, **kwargs): pass @@ -138,31 +180,34 @@ def method1(self, **kwargs): method1(self, arg1=[[1, 1, 1]]) fmt_docval_warning_msg = ( - "fmt_docval_args will be deprecated in a future version of HDMF. Instead of using fmt_docval_args, " - "call the function directly with the kwargs. Please note that fmt_docval_args " - "removes all arguments not accepted by the function's docval, so if you are passing kwargs that " - "includes extra arguments and the function's docval does not allow extra arguments (allow_extra=True " - "is set), then you will need to pop the extra arguments out of kwargs before calling the function." + "fmt_docval_args will be deprecated in a future version of HDMF. Instead of" + " using fmt_docval_args, call the function directly with the kwargs. Please" + " note that fmt_docval_args removes all arguments not accepted by the" + " function's docval, so if you are passing kwargs that includes extra arguments" + " and the function's docval does not allow extra arguments (allow_extra=True is" + " set), then you will need to pop the extra arguments out of kwargs before" + " calling the function." 
) def test_fmt_docval_args(self): - """ Test that fmt_docval_args parses the args and strips extra args """ + """Test that fmt_docval_args parses the args and strips extra args""" test_kwargs = { - 'arg1': 'a string', - 'arg2': 1, - 'arg3': True, - 'hello': 'abc', - 'list': ['abc', 1, 2, 3] + "arg1": "a string", + "arg2": 1, + "arg3": True, + "hello": "abc", + "list": ["abc", 1, 2, 3], } with self.assertWarnsWith(PendingDeprecationWarning, self.fmt_docval_warning_msg): rec_args, rec_kwargs = fmt_docval_args(self.test_obj.basic_add2_kw, test_kwargs) - exp_args = ['a string', 1] + exp_args = ["a string", 1] self.assertListEqual(rec_args, exp_args) - exp_kwargs = {'arg3': True} + exp_kwargs = {"arg3": True} self.assertDictEqual(rec_kwargs, exp_kwargs) def test_fmt_docval_args_no_docval(self): - """ Test that fmt_docval_args raises an error when run on function without docval """ + """Test that fmt_docval_args raises an error when run on function without docval""" + def method1(self, **kwargs): pass @@ -171,378 +216,415 @@ def method1(self, **kwargs): fmt_docval_args(method1, {}) def test_fmt_docval_args_allow_extra(self): - """ Test that fmt_docval_args works """ + """Test that fmt_docval_args works""" test_kwargs = { - 'arg1': 'a string', - 'arg2': 1, - 'arg3': True, - 'hello': 'abc', - 'list': ['abc', 1, 2, 3] + "arg1": "a string", + "arg2": 1, + "arg3": True, + "hello": "abc", + "list": ["abc", 1, 2, 3], } with self.assertWarnsWith(PendingDeprecationWarning, self.fmt_docval_warning_msg): rec_args, rec_kwargs = fmt_docval_args(self.test_obj.basic_add2_kw_allow_extra, test_kwargs) - exp_args = ['a string', 1] + exp_args = ["a string", 1] self.assertListEqual(rec_args, exp_args) - exp_kwargs = {'arg3': True, 'hello': 'abc', 'list': ['abc', 1, 2, 3]} + exp_kwargs = {"arg3": True, "hello": "abc", "list": ["abc", 1, 2, 3]} self.assertDictEqual(rec_kwargs, exp_kwargs) def test_call_docval_func(self): """Test that call_docval_func strips extra args and calls the function.""" test_kwargs = { - 'arg1': 'a string', - 'arg2': 1, - 'arg3': True, - 'hello': 'abc', - 'list': ['abc', 1, 2, 3] + "arg1": "a string", + "arg2": 1, + "arg3": True, + "hello": "abc", + "list": ["abc", 1, 2, 3], } msg = ( - "call_docval_func will be deprecated in a future version of HDMF. Instead of using call_docval_func, " - "call the function directly with the kwargs. Please note that call_docval_func " - "removes all arguments not accepted by the function's docval, so if you are passing kwargs that " - "includes extra arguments and the function's docval does not allow extra arguments (allow_extra=True " - "is set), then you will need to pop the extra arguments out of kwargs before calling the function." + "call_docval_func will be deprecated in a future version of HDMF. Instead" + " of using call_docval_func, call the function directly with the kwargs." + " Please note that call_docval_func removes all arguments not accepted by" + " the function's docval, so if you are passing kwargs that includes extra" + " arguments and the function's docval does not allow extra arguments" + " (allow_extra=True is set), then you will need to pop the extra arguments" + " out of kwargs before calling the function." 
) with self.assertWarnsWith(PendingDeprecationWarning, msg): ret_kwargs = call_docval_func(self.test_obj.basic_add2_kw, test_kwargs) - exp_kwargs = { - 'arg1': 'a string', - 'arg2': 1, - 'arg3': True - } + exp_kwargs = {"arg1": "a string", "arg2": 1, "arg3": True} self.assertDictEqual(ret_kwargs, exp_kwargs) def test_docval_add(self): """Test that docval works with a single positional - argument + argument """ - kwargs = self.test_obj.basic_add('a string') - self.assertDictEqual(kwargs, {'arg1': 'a string'}) + kwargs = self.test_obj.basic_add("a string") + self.assertDictEqual(kwargs, {"arg1": "a string"}) def test_docval_add_kw(self): """Test that docval works with a single positional - argument passed as key-value + argument passed as key-value """ - kwargs = self.test_obj.basic_add(arg1='a string') - self.assertDictEqual(kwargs, {'arg1': 'a string'}) + kwargs = self.test_obj.basic_add(arg1="a string") + self.assertDictEqual(kwargs, {"arg1": "a string"}) def test_docval_add_missing_args(self): """Test that docval catches missing argument - with a single positional argument + with a single positional argument """ with self.assertRaisesWith(TypeError, "MyTestClass.basic_add: missing argument 'arg1'"): self.test_obj.basic_add() def test_docval_add2(self): """Test that docval works with two positional - arguments + arguments """ - kwargs = self.test_obj.basic_add2('a string', 100) - self.assertDictEqual(kwargs, {'arg1': 'a string', 'arg2': 100}) + kwargs = self.test_obj.basic_add2("a string", 100) + self.assertDictEqual(kwargs, {"arg1": "a string", "arg2": 100}) def test_docval_add2_w_unicode(self): """Test that docval works with two positional - arguments + arguments """ - kwargs = self.test_obj.basic_add2(u'a string', 100) - self.assertDictEqual(kwargs, {'arg1': u'a string', 'arg2': 100}) + kwargs = self.test_obj.basic_add2("a string", 100) + self.assertDictEqual(kwargs, {"arg1": "a string", "arg2": 100}) def test_docval_add2_kw_default(self): """Test that docval works with two positional - arguments and a keyword argument when using - default keyword argument value + arguments and a keyword argument when using + default keyword argument value """ - kwargs = self.test_obj.basic_add2_kw('a string', 100) - self.assertDictEqual(kwargs, {'arg1': 'a string', 'arg2': 100, 'arg3': False}) + kwargs = self.test_obj.basic_add2_kw("a string", 100) + self.assertDictEqual(kwargs, {"arg1": "a string", "arg2": 100, "arg3": False}) def test_docval_add2_pos_as_kw(self): """Test that docval works with two positional - arguments and a keyword argument when using - default keyword argument value, but pass - positional arguments by key-value + arguments and a keyword argument when using + default keyword argument value, but pass + positional arguments by key-value """ - kwargs = self.test_obj.basic_add2_kw(arg1='a string', arg2=100) - self.assertDictEqual(kwargs, {'arg1': 'a string', 'arg2': 100, 'arg3': False}) + kwargs = self.test_obj.basic_add2_kw(arg1="a string", arg2=100) + self.assertDictEqual(kwargs, {"arg1": "a string", "arg2": 100, "arg3": False}) def test_docval_add2_kw_kw_syntax(self): """Test that docval works with two positional - arguments and a keyword argument when specifying - keyword argument value with keyword syntax + arguments and a keyword argument when specifying + keyword argument value with keyword syntax """ - kwargs = self.test_obj.basic_add2_kw('a string', 100, arg3=True) - self.assertDictEqual(kwargs, {'arg1': 'a string', 'arg2': 100, 'arg3': True}) + kwargs = 
self.test_obj.basic_add2_kw("a string", 100, arg3=True) + self.assertDictEqual(kwargs, {"arg1": "a string", "arg2": 100, "arg3": True}) def test_docval_add2_kw_all_kw_syntax(self): """Test that docval works with two positional - arguments and a keyword argument when specifying - all arguments by key-value + arguments and a keyword argument when specifying + all arguments by key-value """ - kwargs = self.test_obj.basic_add2_kw(arg1='a string', arg2=100, arg3=True) - self.assertDictEqual(kwargs, {'arg1': 'a string', 'arg2': 100, 'arg3': True}) + kwargs = self.test_obj.basic_add2_kw(arg1="a string", arg2=100, arg3=True) + self.assertDictEqual(kwargs, {"arg1": "a string", "arg2": 100, "arg3": True}) def test_docval_add2_kw_pos_syntax(self): """Test that docval works with two positional - arguments and a keyword argument when specifying - keyword argument value with positional syntax + arguments and a keyword argument when specifying + keyword argument value with positional syntax """ - kwargs = self.test_obj.basic_add2_kw('a string', 100, True) - self.assertDictEqual(kwargs, {'arg1': 'a string', 'arg2': 100, 'arg3': True}) + kwargs = self.test_obj.basic_add2_kw("a string", 100, True) + self.assertDictEqual(kwargs, {"arg1": "a string", "arg2": 100, "arg3": True}) def test_docval_add2_kw_pos_syntax_missing_args(self): """Test that docval catches incorrect type with two positional - arguments and a keyword argument when specifying - keyword argument value with positional syntax + arguments and a keyword argument when specifying + keyword argument value with positional syntax """ msg = "MyTestClass.basic_add2_kw: incorrect type for 'arg2' (got 'str', expected 'int')" with self.assertRaisesWith(TypeError, msg): - self.test_obj.basic_add2_kw('a string', 'bad string') + self.test_obj.basic_add2_kw("a string", "bad string") def test_docval_add_sub(self): """Test that docval works with a two positional arguments, - where the second is specified by the subclass implementation + where the second is specified by the subclass implementation """ - kwargs = self.test_obj_sub.basic_add('a string', 100) - expected = {'arg1': 'a string', 'arg2': 100} + kwargs = self.test_obj_sub.basic_add("a string", 100) + expected = {"arg1": "a string", "arg2": 100} self.assertDictEqual(kwargs, expected) def test_docval_add2_kw_default_sub(self): """Test that docval works with a four positional arguments and - two keyword arguments, where two positional and one keyword - argument is specified in both the parent and subclass implementations + two keyword arguments, where two positional and one keyword + argument is specified in both the parent and subclass implementations """ - kwargs = self.test_obj_sub.basic_add2_kw('a string', 100, 'another string', 200.0) - expected = {'arg1': 'a string', 'arg2': 100, - 'arg4': 'another string', 'arg5': 200.0, - 'arg3': False, 'arg6': None} + kwargs = self.test_obj_sub.basic_add2_kw("a string", 100, "another string", 200.0) + expected = { + "arg1": "a string", + "arg2": 100, + "arg4": "another string", + "arg5": 200.0, + "arg3": False, + "arg6": None, + } self.assertDictEqual(kwargs, expected) def test_docval_add2_kw_default_sub_missing_args(self): """Test that docval catches missing arguments with a four positional arguments - and two keyword arguments, where two positional and one keyword - argument is specified in both the parent and subclass implementations, - when using default values for keyword arguments + and two keyword arguments, where two positional and one keyword + argument is 
specified in both the parent and subclass implementations, + when using default values for keyword arguments """ with self.assertRaisesWith(TypeError, "MyTestSubclass.basic_add2_kw: missing argument 'arg5'"): - self.test_obj_sub.basic_add2_kw('a string', 100, 'another string') + self.test_obj_sub.basic_add2_kw("a string", 100, "another string") def test_docval_add2_kw_kwsyntax_sub(self): """Test that docval works when called with a four positional - arguments and two keyword arguments, where two positional - and one keyword argument is specified in both the parent - and subclass implementations + arguments and two keyword arguments, where two positional + and one keyword argument is specified in both the parent + and subclass implementations """ - kwargs = self.test_obj_sub.basic_add2_kw('a string', 100, 'another string', 200.0, arg6=True) - expected = {'arg1': 'a string', 'arg2': 100, - 'arg4': 'another string', 'arg5': 200.0, - 'arg3': False, 'arg6': True} + kwargs = self.test_obj_sub.basic_add2_kw("a string", 100, "another string", 200.0, arg6=True) + expected = { + "arg1": "a string", + "arg2": 100, + "arg4": "another string", + "arg5": 200.0, + "arg3": False, + "arg6": True, + } self.assertDictEqual(kwargs, expected) def test_docval_add2_kw_kwsyntax_sub_missing_args(self): """Test that docval catches missing arguments when called with a four positional - arguments and two keyword arguments, where two positional and one keyword - argument is specified in both the parent and subclass implementations + arguments and two keyword arguments, where two positional and one keyword + argument is specified in both the parent and subclass implementations """ with self.assertRaisesWith(TypeError, "MyTestSubclass.basic_add2_kw: missing argument 'arg5'"): - self.test_obj_sub.basic_add2_kw('a string', 100, 'another string', arg6=True) + self.test_obj_sub.basic_add2_kw("a string", 100, "another string", arg6=True) def test_docval_add2_kw_kwsyntax_sub_nonetype_arg(self): """Test that docval catches NoneType when called with a four positional - arguments and two keyword arguments, where two positional and one keyword - argument is specified in both the parent and subclass implementations + arguments and two keyword arguments, where two positional and one keyword + argument is specified in both the parent and subclass implementations """ msg = "MyTestSubclass.basic_add2_kw: None is not allowed for 'arg5' (expected 'float', not None)" with self.assertRaisesWith(TypeError, msg): - self.test_obj_sub.basic_add2_kw('a string', 100, 'another string', None, arg6=True) + self.test_obj_sub.basic_add2_kw("a string", 100, "another string", None, arg6=True) def test_only_kw_no_args(self): """Test that docval parses arguments when only keyword - arguments exist, and no arguments are specified + arguments exist, and no arguments are specified """ kwargs = self.test_obj.basic_only_kw() - self.assertDictEqual(kwargs, {'arg1': 'a', 'arg2': 1}) + self.assertDictEqual(kwargs, {"arg1": "a", "arg2": 1}) def test_only_kw_arg1_no_arg2(self): """Test that docval parses arguments when only keyword - arguments exist, and only first argument is specified - as key-value + arguments exist, and only first argument is specified + as key-value """ - kwargs = self.test_obj.basic_only_kw(arg1='b') - self.assertDictEqual(kwargs, {'arg1': 'b', 'arg2': 1}) + kwargs = self.test_obj.basic_only_kw(arg1="b") + self.assertDictEqual(kwargs, {"arg1": "b", "arg2": 1}) def test_only_kw_arg1_pos_no_arg2(self): """Test that docval parses arguments 
when only keyword - arguments exist, and only first argument is specified - as positional argument + arguments exist, and only first argument is specified + as positional argument """ - kwargs = self.test_obj.basic_only_kw('b') - self.assertDictEqual(kwargs, {'arg1': 'b', 'arg2': 1}) + kwargs = self.test_obj.basic_only_kw("b") + self.assertDictEqual(kwargs, {"arg1": "b", "arg2": 1}) def test_only_kw_arg2_no_arg1(self): """Test that docval parses arguments when only keyword - arguments exist, and only second argument is specified - as key-value + arguments exist, and only second argument is specified + as key-value """ kwargs = self.test_obj.basic_only_kw(arg2=2) - self.assertDictEqual(kwargs, {'arg1': 'a', 'arg2': 2}) + self.assertDictEqual(kwargs, {"arg1": "a", "arg2": 2}) def test_only_kw_arg1_arg2(self): """Test that docval parses arguments when only keyword - arguments exist, and both arguments are specified - as key-value + arguments exist, and both arguments are specified + as key-value """ - kwargs = self.test_obj.basic_only_kw(arg1='b', arg2=2) - self.assertDictEqual(kwargs, {'arg1': 'b', 'arg2': 2}) + kwargs = self.test_obj.basic_only_kw(arg1="b", arg2=2) + self.assertDictEqual(kwargs, {"arg1": "b", "arg2": 2}) def test_only_kw_arg1_arg2_pos(self): """Test that docval parses arguments when only keyword - arguments exist, and both arguments are specified - as positional arguments + arguments exist, and both arguments are specified + as positional arguments """ - kwargs = self.test_obj.basic_only_kw('b', 2) - self.assertDictEqual(kwargs, {'arg1': 'b', 'arg2': 2}) + kwargs = self.test_obj.basic_only_kw("b", 2) + self.assertDictEqual(kwargs, {"arg1": "b", "arg2": 2}) def test_extra_kwarg(self): """Test that docval parses arguments when only keyword - arguments exist, and both arguments are specified - as positional arguments + arguments exist, and both arguments are specified + as positional arguments """ with self.assertRaises(TypeError): - self.test_obj.basic_add2_kw('a string', 100, bar=1000) + self.test_obj.basic_add2_kw("a string", 100, bar=1000) def test_extra_args_pos_only(self): """Test that docval raises an error if too many positional - arguments are specified + arguments are specified """ - msg = ("MyTestClass.basic_add2_kw: Expected at most 3 arguments ['arg1', 'arg2', 'arg3'], got 4: 4 positional " - "and 0 keyword []") + msg = ( + "MyTestClass.basic_add2_kw: Expected at most 3 arguments ['arg1', 'arg2'," + " 'arg3'], got 4: 4 positional and 0 keyword []" + ) with self.assertRaisesWith(TypeError, msg): - self.test_obj.basic_add2_kw('a string', 100, True, 'extra') + self.test_obj.basic_add2_kw("a string", 100, True, "extra") def test_extra_args_pos_kw(self): """Test that docval raises an error if too many positional - arguments are specified and a keyword arg is specified + arguments are specified and a keyword arg is specified """ - msg = ("MyTestClass.basic_add2_kw: Expected at most 3 arguments ['arg1', 'arg2', 'arg3'], got 4: 3 positional " - "and 1 keyword ['arg3']") + msg = ( + "MyTestClass.basic_add2_kw: Expected at most 3 arguments ['arg1', 'arg2'," + " 'arg3'], got 4: 3 positional and 1 keyword ['arg3']" + ) with self.assertRaisesWith(TypeError, msg): - self.test_obj.basic_add2_kw('a string', 'extra', 100, arg3=True) + self.test_obj.basic_add2_kw("a string", "extra", 100, arg3=True) def test_extra_kwargs_pos_kw(self): """Test that docval raises an error if extra keyword - arguments are specified + arguments are specified """ - msg = ("MyTestClass.basic_add2_kw: 
Expected at most 3 arguments ['arg1', 'arg2', 'arg3'], got 4: 2 positional " - "and 2 keyword ['arg3', 'extra']") + msg = ( + "MyTestClass.basic_add2_kw: Expected at most 3 arguments ['arg1', 'arg2'," + " 'arg3'], got 4: 2 positional and 2 keyword ['arg3', 'extra']" + ) with self.assertRaisesWith(TypeError, msg): - self.test_obj.basic_add2_kw('a string', 100, extra='extra', arg3=True) + self.test_obj.basic_add2_kw("a string", 100, extra="extra", arg3=True) def test_extra_args_pos_only_ok(self): """Test that docval raises an error if too many positional - arguments are specified even if allow_extra is True + arguments are specified even if allow_extra is True """ - msg = ("MyTestClass.basic_add2_kw_allow_extra: Expected at most 3 arguments ['arg1', 'arg2', 'arg3'], got " - "4 positional") + msg = ( + "MyTestClass.basic_add2_kw_allow_extra: Expected at most 3 arguments" + " ['arg1', 'arg2', 'arg3'], got 4 positional" + ) with self.assertRaisesWith(TypeError, msg): - self.test_obj.basic_add2_kw_allow_extra('a string', 100, True, 'extra', extra='extra') + self.test_obj.basic_add2_kw_allow_extra("a string", 100, True, "extra", extra="extra") def test_extra_args_pos_kw_ok(self): """Test that docval does not raise an error if too many - keyword arguments are specified and allow_extra is True + keyword arguments are specified and allow_extra is True """ - kwargs = self.test_obj.basic_add2_kw_allow_extra('a string', 100, True, extra='extra') - self.assertDictEqual(kwargs, {'arg1': 'a string', 'arg2': 100, 'arg3': True, 'extra': 'extra'}) + kwargs = self.test_obj.basic_add2_kw_allow_extra("a string", 100, True, extra="extra") + self.assertDictEqual( + kwargs, + {"arg1": "a string", "arg2": 100, "arg3": True, "extra": "extra"}, + ) def test_dup_kw(self): """Test that docval raises an error if a keyword argument - captures a positional argument before all positional - arguments have been resolved + captures a positional argument before all positional + arguments have been resolved """ - with self.assertRaisesWith(TypeError, "MyTestClass.basic_add2_kw: got multiple values for argument 'arg1'"): - self.test_obj.basic_add2_kw('a string', 100, arg1='extra') + with self.assertRaisesWith( + TypeError, + "MyTestClass.basic_add2_kw: got multiple values for argument 'arg1'", + ): + self.test_obj.basic_add2_kw("a string", 100, arg1="extra") def test_extra_args_dup_kw(self): """Test that docval raises an error if a keyword argument - captures a positional argument before all positional - arguments have been resolved and allow_extra is True + captures a positional argument before all positional + arguments have been resolved and allow_extra is True """ msg = "MyTestClass.basic_add2_kw_allow_extra: got multiple values for argument 'arg1'" with self.assertRaisesWith(TypeError, msg): - self.test_obj.basic_add2_kw_allow_extra('a string', 100, True, arg1='extra') + self.test_obj.basic_add2_kw_allow_extra("a string", 100, True, arg1="extra") def test_unsupported_docval_term(self): """Test that docval does not allow setting of arguments - marked as unsupported + marked as unsupported """ msg = "docval for arg1: keys ['unsupported'] are not supported by docval" with self.assertRaisesWith(Exception, msg): - @docval({'name': 'arg1', 'type': 'array_data', 'doc': 'this is a bad shape', 'unsupported': 'hi!'}) + + @docval({"name": "arg1", "type": "array_data", "doc": "this is a bad shape", "unsupported": "hi!"}) def method(self, **kwargs): pass def test_catch_dup_names(self): - """Test that docval does not allow duplicate 
argument names - """ - @docval({'name': 'arg1', 'type': 'array_data', 'doc': 'this is a bad shape'}, - {'name': 'arg1', 'type': 'array_data', 'doc': 'this is a bad shape2'}) + """Test that docval does not allow duplicate argument names""" + + @docval( + {"name": "arg1", "type": "array_data", "doc": "this is a bad shape"}, + {"name": "arg1", "type": "array_data", "doc": "this is a bad shape2"}, + ) def method(self, **kwargs): pass + msg = "TestDocValidator.test_catch_dup_names..method: The following names are duplicated: ['arg1']" with self.assertRaisesWith(ValueError, msg): method(self, arg1=[1]) def test_get_docval_all(self): - """Test that get_docval returns a tuple of the docval arguments - """ + """Test that get_docval returns a tuple of the docval arguments""" args = get_docval(self.test_obj.basic_add2) - self.assertTupleEqual(args, ({'name': 'arg1', 'type': str, 'doc': 'argument1 is a str'}, - {'name': 'arg2', 'type': int, 'doc': 'argument2 is a int'})) + self.assertTupleEqual( + args, + ( + {"name": "arg1", "type": str, "doc": "argument1 is a str"}, + {"name": "arg2", "type": int, "doc": "argument2 is a int"}, + ), + ) def test_get_docval_one_arg(self): - """Test that get_docval returns the matching docval argument - """ - arg = get_docval(self.test_obj.basic_add2, 'arg2') - self.assertTupleEqual(arg, ({'name': 'arg2', 'type': int, 'doc': 'argument2 is a int'},)) + """Test that get_docval returns the matching docval argument""" + arg = get_docval(self.test_obj.basic_add2, "arg2") + self.assertTupleEqual(arg, ({"name": "arg2", "type": int, "doc": "argument2 is a int"},)) def test_get_docval_two_args(self): - """Test that get_docval returns the matching docval arguments in order - """ - args = get_docval(self.test_obj.basic_add2, 'arg2', 'arg1') - self.assertTupleEqual(args, ({'name': 'arg2', 'type': int, 'doc': 'argument2 is a int'}, - {'name': 'arg1', 'type': str, 'doc': 'argument1 is a str'})) + """Test that get_docval returns the matching docval arguments in order""" + args = get_docval(self.test_obj.basic_add2, "arg2", "arg1") + self.assertTupleEqual( + args, + ( + {"name": "arg2", "type": int, "doc": "argument2 is a int"}, + {"name": "arg1", "type": str, "doc": "argument1 is a str"}, + ), + ) def test_get_docval_missing_arg(self): - """Test that get_docval throws error if the matching docval argument is not found - """ - with self.assertRaisesWith(ValueError, "Function basic_add2 does not have docval argument 'arg3'"): - get_docval(self.test_obj.basic_add2, 'arg3') + """Test that get_docval throws error if the matching docval argument is not found""" + with self.assertRaisesWith( + ValueError, + "Function basic_add2 does not have docval argument 'arg3'", + ): + get_docval(self.test_obj.basic_add2, "arg3") def test_get_docval_missing_args(self): - """Test that get_docval throws error if the matching docval arguments is not found - """ - with self.assertRaisesWith(ValueError, "Function basic_add2 does not have docval argument 'arg3'"): - get_docval(self.test_obj.basic_add2, 'arg3', 'arg4') + """Test that get_docval throws error if the matching docval arguments is not found""" + with self.assertRaisesWith( + ValueError, + "Function basic_add2 does not have docval argument 'arg3'", + ): + get_docval(self.test_obj.basic_add2, "arg3", "arg4") def test_get_docval_missing_arg_of_many_ok(self): - """Test that get_docval throws error if the matching docval arguments is not found - """ - with self.assertRaisesWith(ValueError, "Function basic_add2 does not have docval argument 'arg3'"): - 
get_docval(self.test_obj.basic_add2, 'arg2', 'arg3') + """Test that get_docval throws error if the matching docval arguments is not found""" + with self.assertRaisesWith( + ValueError, + "Function basic_add2 does not have docval argument 'arg3'", + ): + get_docval(self.test_obj.basic_add2, "arg2", "arg3") def test_get_docval_none(self): - """Test that get_docval returns an empty tuple if there is no docval - """ + """Test that get_docval returns an empty tuple if there is no docval""" args = get_docval(self.test_obj.__init__) self.assertTupleEqual(args, tuple()) def test_get_docval_none_arg(self): - """Test that get_docval throws error if there is no docval and an argument name is passed - """ - with self.assertRaisesWith(ValueError, 'Function __init__ has no docval arguments'): - get_docval(self.test_obj.__init__, 'arg3') + """Test that get_docval throws error if there is no docval and an argument name is passed""" + with self.assertRaisesWith(ValueError, "Function __init__ has no docval arguments"): + get_docval(self.test_obj.__init__, "arg3") def test_bool_type(self): - @docval({'name': 'arg1', 'type': bool, 'doc': 'this is a bool'}) + @docval({"name": "arg1", "type": bool, "doc": "this is a bool"}) def method(self, **kwargs): - return popargs('arg1', kwargs) + return popargs("arg1", kwargs) res = method(self, arg1=True) self.assertEqual(res, True) @@ -553,9 +635,9 @@ def method(self, **kwargs): self.assertIsInstance(res, np.bool_) def test_bool_string_type(self): - @docval({'name': 'arg1', 'type': 'bool', 'doc': 'this is a bool'}) + @docval({"name": "arg1", "type": "bool", "doc": "this is a bool"}) def method(self, **kwargs): - return popargs('arg1', kwargs) + return popargs("arg1", kwargs) res = method(self, arg1=True) self.assertEqual(res, True) @@ -567,29 +649,35 @@ def method(self, **kwargs): def test_uint_type(self): """Test that docval type specification of np.uint32 works as expected.""" - @docval({'name': 'arg1', 'type': np.uint32, 'doc': 'this is a uint'}) + + @docval({"name": "arg1", "type": np.uint32, "doc": "this is a uint"}) def method(self, **kwargs): - return popargs('arg1', kwargs) + return popargs("arg1", kwargs) res = method(self, arg1=np.uint32(1)) self.assertEqual(res, np.uint32(1)) self.assertIsInstance(res, np.uint32) - msg = ("TestDocValidator.test_uint_type..method: incorrect type for 'arg1' (got 'uint8', expected " - "'uint32')") + msg = ( + "TestDocValidator.test_uint_type..method: incorrect type for 'arg1'" + " (got 'uint8', expected 'uint32')" + ) with self.assertRaisesWith(TypeError, msg): method(self, arg1=np.uint8(1)) - msg = ("TestDocValidator.test_uint_type..method: incorrect type for 'arg1' (got 'uint64', expected " - "'uint32')") + msg = ( + "TestDocValidator.test_uint_type..method: incorrect type for 'arg1'" + " (got 'uint64', expected 'uint32')" + ) with self.assertRaisesWith(TypeError, msg): method(self, arg1=np.uint64(1)) def test_uint_string_type(self): """Test that docval type specification of string 'uint' matches np.uint of all available precisions.""" - @docval({'name': 'arg1', 'type': 'uint', 'doc': 'this is a uint'}) + + @docval({"name": "arg1", "type": "uint", "doc": "this is a uint"}) def method(self, **kwargs): - return popargs('arg1', kwargs) + return popargs("arg1", kwargs) res = method(self, arg1=np.uint(1)) self.assertEqual(res, np.uint(1)) @@ -612,9 +700,12 @@ def method(self, **kwargs): self.assertIsInstance(res, np.uint64) def test_allow_positional_warn(self): - @docval({'name': 'arg1', 'type': bool, 'doc': 'this is a bool'}, 
allow_positional=AllowPositional.WARNING) + @docval( + {"name": "arg1", "type": bool, "doc": "this is a bool"}, + allow_positional=AllowPositional.WARNING, + ) def method(self, **kwargs): - return popargs('arg1', kwargs) + return popargs("arg1", kwargs) # check that supplying a keyword arg is OK res = method(self, arg1=True) @@ -622,16 +713,22 @@ def method(self, **kwargs): self.assertIsInstance(res, bool) # check that supplying a positional arg raises a warning - msg = ('TestDocValidator.test_allow_positional_warn..method: ' - 'Using positional arguments for this method is discouraged and will be deprecated in a future major ' - 'release. Please use keyword arguments to ensure future compatibility.') + msg = ( + "TestDocValidator.test_allow_positional_warn..method: Using" + " positional arguments for this method is discouraged and will be" + " deprecated in a future major release. Please use keyword arguments to" + " ensure future compatibility." + ) with self.assertWarnsWith(FutureWarning, msg): method(self, True) def test_allow_positional_error(self): - @docval({'name': 'arg1', 'type': bool, 'doc': 'this is a bool'}, allow_positional=AllowPositional.ERROR) + @docval( + {"name": "arg1", "type": bool, "doc": "this is a bool"}, + allow_positional=AllowPositional.ERROR, + ) def method(self, **kwargs): - return popargs('arg1', kwargs) + return popargs("arg1", kwargs) # check that supplying a keyword arg is OK res = method(self, arg1=True) @@ -639,21 +736,33 @@ def method(self, **kwargs): self.assertIsInstance(res, bool) # check that supplying a positional arg raises an error - msg = ('TestDocValidator.test_allow_positional_error..method: ' - 'Only keyword arguments (e.g., func(argname=value, ...)) are allowed for this method.') + msg = ( + "TestDocValidator.test_allow_positional_error..method: Only keyword" + " arguments (e.g., func(argname=value, ...)) are allowed for this method." 
+ ) with self.assertRaisesWith(SyntaxError, msg): method(self, True) def test_allow_none_false(self): """Test that docval with allow_none=True and non-None default value works""" - @docval({'name': 'arg1', 'type': bool, 'doc': 'this is a bool or None with a default', 'default': True, - 'allow_none': False}) + + @docval( + { + "name": "arg1", + "type": bool, + "doc": "this is a bool or None with a default", + "default": True, + "allow_none": False, + } + ) def method(self, **kwargs): - return popargs('arg1', kwargs) + return popargs("arg1", kwargs) # if provided, None is not allowed - msg = ("TestDocValidator.test_allow_none_false..method: incorrect type for 'arg1' " - "(got 'NoneType', expected 'bool')") + msg = ( + "TestDocValidator.test_allow_none_false..method: incorrect type for" + " 'arg1' (got 'NoneType', expected 'bool')" + ) with self.assertRaisesWith(TypeError, msg): res = method(self, arg1=None) @@ -663,10 +772,18 @@ def method(self, **kwargs): def test_allow_none(self): """Test that docval with allow_none=True and non-None default value works""" - @docval({'name': 'arg1', 'type': bool, 'doc': 'this is a bool or None with a default', 'default': True, - 'allow_none': True}) + + @docval( + { + "name": "arg1", + "type": bool, + "doc": "this is a bool or None with a default", + "default": True, + "allow_none": True, + } + ) def method(self, **kwargs): - return popargs('arg1', kwargs) + return popargs("arg1", kwargs) # if provided, None is allowed res = method(self, arg1=None) @@ -678,10 +795,18 @@ def method(self, **kwargs): def test_allow_none_redundant(self): """Test that docval with allow_none=True and default=None works""" - @docval({'name': 'arg1', 'type': bool, 'doc': 'this is a bool or None with a default', 'default': None, - 'allow_none': True}) + + @docval( + { + "name": "arg1", + "type": bool, + "doc": "this is a bool or None with a default", + "default": None, + "allow_none": True, + } + ) def method(self, **kwargs): - return popargs('arg1', kwargs) + return popargs("arg1", kwargs) # if provided, None is allowed res = method(self, arg1=None) @@ -693,151 +818,182 @@ def method(self, **kwargs): def test_allow_none_no_default(self): """Test that docval with allow_none=True and no default raises an error""" - msg = ("docval for arg1: allow_none=True can only be set if a default value is provided.") + msg = "docval for arg1: allow_none=True can only be set if a default value is provided." 
with self.assertRaisesWith(Exception, msg): - @docval({'name': 'arg1', 'type': bool, 'doc': 'this is a bool or None with a default', 'allow_none': True}) + + @docval( + { + "name": "arg1", + "type": bool, + "doc": "this is a bool or None with a default", + "allow_none": True, + } + ) def method(self, **kwargs): - return popargs('arg1', kwargs) + return popargs("arg1", kwargs) def test_enum_str(self): """Test that the basic usage of an enum check on strings works""" - @docval({'name': 'arg1', 'type': str, 'doc': 'an arg', 'enum': ['a', 'b']}) # also use enum: list + + @docval({"name": "arg1", "type": str, "doc": "an arg", "enum": ["a", "b"]}) # also use enum: list def method(self, **kwargs): - return popargs('arg1', kwargs) + return popargs("arg1", kwargs) - self.assertEqual(method(self, 'a'), 'a') - self.assertEqual(method(self, 'b'), 'b') + self.assertEqual(method(self, "a"), "a") + self.assertEqual(method(self, "b"), "b") - msg = ("TestDocValidator.test_enum_str.<locals>.method: " - "forbidden value for 'arg1' (got 'c', expected ['a', 'b'])") + msg = ( + "TestDocValidator.test_enum_str.<locals>.method: forbidden value for 'arg1' (got 'c', expected ['a', 'b'])" + ) with self.assertRaisesWith(ValueError, msg): - method(self, 'c') + method(self, "c") def test_enum_int(self): """Test that the basic usage of an enum check on ints works""" - @docval({'name': 'arg1', 'type': int, 'doc': 'an arg', 'enum': (1, 2)}) + + @docval({"name": "arg1", "type": int, "doc": "an arg", "enum": (1, 2)}) def method(self, **kwargs): - return popargs('arg1', kwargs) + return popargs("arg1", kwargs) self.assertEqual(method(self, 1), 1) self.assertEqual(method(self, 2), 2) - msg = ("TestDocValidator.test_enum_int.<locals>.method: " - "forbidden value for 'arg1' (got 3, expected (1, 2))") + msg = "TestDocValidator.test_enum_int.<locals>.method: forbidden value for 'arg1' (got 3, expected (1, 2))" with self.assertRaisesWith(ValueError, msg): method(self, 3) def test_enum_uint(self): """Test that the basic usage of an enum check on uints works""" - @docval({'name': 'arg1', 'type': np.uint, 'doc': 'an arg', 'enum': (np.uint(1), np.uint(2))}) + + @docval({"name": "arg1", "type": np.uint, "doc": "an arg", "enum": (np.uint(1), np.uint(2))}) def method(self, **kwargs): - return popargs('arg1', kwargs) + return popargs("arg1", kwargs) self.assertEqual(method(self, np.uint(1)), np.uint(1)) self.assertEqual(method(self, np.uint(2)), np.uint(2)) - msg = ("TestDocValidator.test_enum_uint.<locals>.method: " - "forbidden value for 'arg1' (got 3, expected (1, 2))") + msg = "TestDocValidator.test_enum_uint.<locals>.method: forbidden value for 'arg1' (got 3, expected (1, 2))" with self.assertRaisesWith(ValueError, msg): method(self, np.uint(3)) def test_enum_float(self): """Test that the basic usage of an enum check on floats works""" - @docval({'name': 'arg1', 'type': float, 'doc': 'an arg', 'enum': (3.14, )}) + + @docval({"name": "arg1", "type": float, "doc": "an arg", "enum": (3.14,)}) def method(self, **kwargs): - return popargs('arg1', kwargs) + return popargs("arg1", kwargs) self.assertEqual(method(self, 3.14), 3.14) - msg = ("TestDocValidator.test_enum_float.<locals>.method: " - "forbidden value for 'arg1' (got 3.0, expected (3.14,))") + msg = "TestDocValidator.test_enum_float.<locals>.method: forbidden value for 'arg1' (got 3.0, expected (3.14,))" with self.assertRaisesWith(ValueError, msg): - method(self, 3.)
+ method(self, 3.0) def test_enum_bool_mixed(self): """Test that the basic usage of an enum check on a tuple of bool, int, float, and string works""" - @docval({'name': 'arg1', 'type': (bool, int, float, str, np.uint), 'doc': 'an arg', - 'enum': (True, 1, 1.0, 'true', np.uint(1))}) + + @docval( + { + "name": "arg1", + "type": (bool, int, float, str, np.uint), + "doc": "an arg", + "enum": (True, 1, 1.0, "true", np.uint(1)), + } + ) def method(self, **kwargs): - return popargs('arg1', kwargs) + return popargs("arg1", kwargs) self.assertEqual(method(self, True), True) self.assertEqual(method(self, 1), 1) self.assertEqual(method(self, 1.0), 1.0) - self.assertEqual(method(self, 'true'), 'true') + self.assertEqual(method(self, "true"), "true") self.assertEqual(method(self, np.uint(1)), np.uint(1)) - msg = ("TestDocValidator.test_enum_bool_mixed.<locals>.method: " - "forbidden value for 'arg1' (got 0, expected (True, 1, 1.0, 'true', 1))") + msg = ( + "TestDocValidator.test_enum_bool_mixed.<locals>.method: " + "forbidden value for 'arg1' (got 0, expected (True, 1, 1.0, 'true', 1))" + ) with self.assertRaisesWith(ValueError, msg): method(self, 0) def test_enum_bad_type(self): """Test that docval with an enum check where the arg type includes an invalid enum type fails""" - msg = ("docval for arg1: enum checking cannot be used with arg type (<class 'bool'>, <class 'int'>, " - "<class 'str'>, <class 'numpy.float64'>, <class 'object'>)") + msg = ( + "docval for arg1: enum checking cannot be used with arg type (<class 'bool'>, <class 'int'>, <class 'str'>, <class 'numpy.float64'>, <class 'object'>)" + ) with self.assertRaisesWith(Exception, msg): - @docval({'name': 'arg1', 'type': (bool, int, str, np.float64, object), 'doc': 'an arg', 'enum': (1, 2)}) + + @docval({"name": "arg1", "type": (bool, int, str, np.float64, object), "doc": "an arg", "enum": (1, 2)}) def method(self, **kwargs): - return popargs('arg1', kwargs) + return popargs("arg1", kwargs) def test_enum_none_type(self): """Test that the basic usage of an enum check on None works""" - msg = ("docval for arg1: enum checking cannot be used with arg type None") + msg = "docval for arg1: enum checking cannot be used with arg type None" with self.assertRaisesWith(Exception, msg): - @docval({'name': 'arg1', 'type': None, 'doc': 'an arg', 'enum': (True, 1, 'true')}) + + @docval({"name": "arg1", "type": None, "doc": "an arg", "enum": (True, 1, "true")}) def method(self, **kwargs): pass def test_enum_single_allowed(self): """Test that docval with an enum check on a single value fails""" - msg = ("docval for arg1: enum value must be a list or tuple (received <class 'str'>)") + msg = "docval for arg1: enum value must be a list or tuple (received <class 'str'>)" with self.assertRaisesWith(Exception, msg): - @docval({'name': 'arg1', 'type': str, 'doc': 'an arg', 'enum': 'only one value'}) + + @docval({"name": "arg1", "type": str, "doc": "an arg", "enum": "only one value"}) def method(self, **kwargs): pass def test_enum_str_default(self): """Test that docval with an enum check on strings and a default value works""" - @docval({'name': 'arg1', 'type': str, 'doc': 'an arg', 'default': 'a', 'enum': ['a', 'b']}) + + @docval({"name": "arg1", "type": str, "doc": "an arg", "default": "a", "enum": ["a", "b"]}) def method(self, **kwargs): - return popargs('arg1', kwargs) + return popargs("arg1", kwargs) - self.assertEqual(method(self), 'a') + self.assertEqual(method(self), "a") - msg = ("TestDocValidator.test_enum_str_default.<locals>.method: " - "forbidden value for 'arg1' (got 'c', expected ['a', 'b'])") + msg = ( + "TestDocValidator.test_enum_str_default.<locals>.method: " + "forbidden value for 'arg1' (got 'c', expected ['a', 'b'])" + ) with
self.assertRaisesWith(ValueError, msg): - method(self, 'c') + method(self, "c") def test_enum_str_none_default(self): """Test that docval with an enum check on strings and a None default value works""" - @docval({'name': 'arg1', 'type': str, 'doc': 'an arg', 'default': None, 'enum': ['a', 'b']}) + + @docval({"name": "arg1", "type": str, "doc": "an arg", "default": None, "enum": ["a", "b"]}) def method(self, **kwargs): - return popargs('arg1', kwargs) + return popargs("arg1", kwargs) self.assertIsNone(method(self)) def test_enum_forbidden_values(self): """Test that docval with enum values that include a forbidden type fails""" - msg = ("docval for arg1: enum values are of types not allowed by arg type " - "(got [, ], expected )") + msg = ( + "docval for arg1: enum values are of types not allowed by arg type " + "(got [, ], expected )" + ) with self.assertRaisesWith(Exception, msg): - @docval({'name': 'arg1', 'type': bool, 'doc': 'an arg', 'enum': (True, [])}) + + @docval({"name": "arg1", "type": bool, "doc": "an arg", "enum": (True, [])}) def method(self, **kwargs): pass class TestDocValidatorChain(TestCase): - def setUp(self): - self.obj1 = MyChainClass('base', [[1, 2], [3, 4], [5, 6]], [[10, 20]]) + self.obj1 = MyChainClass("base", [[1, 2], [3, 4], [5, 6]], [[10, 20]]) # note that self.obj1.arg3 == [[1, 2], [3, 4], [5, 6]] def test_type_arg(self): """Test that passing an object for an argument that allows a specific type works""" obj2 = MyChainClass(self.obj1, [[10, 20], [30, 40], [50, 60]], [[10, 20]]) - self.assertEqual(obj2.arg1, 'base') + self.assertEqual(obj2.arg1, "base") def test_type_arg_wrong_type(self): """Test that passing an object for an argument that does not match a specific type raises an error""" @@ -876,8 +1032,10 @@ def test_shape_other_unpack(self): obj2 = MyChainClass(self.obj1, [[10, 20], [30, 40], [50, 60]], [[10, 20]]) obj2.arg3 = object() - err_msg = (r"cannot check shape of object '' for argument 'arg3' " - r"\(expected shape '\(None, 2\)'\)") + err_msg = ( + r"cannot check shape of object '' for argument 'arg3' " + r"\(expected shape '\(None, 2\)'\)" + ) with self.assertRaisesRegex(ValueError, err_msg): MyChainClass(self.obj1, obj2, [[100, 200]]) @@ -918,8 +1076,10 @@ def test_shape_other_unpack_default(self): # shape after an object is initialized obj2.arg4 = object() - err_msg = (r"cannot check shape of object '' for argument 'arg4' " - r"\(expected shape '\(None, 2\)'\)") + err_msg = ( + r"cannot check shape of object '' for argument 'arg4' " + r"\(expected shape '\(None, 2\)'\)" + ) with self.assertRaisesRegex(ValueError, err_msg): MyChainClass(self.obj1, [[100, 200], [300, 400], [500, 600]], arg4=obj2) @@ -928,192 +1088,197 @@ class TestGetargs(TestCase): """Test the getargs function and its error conditions.""" def test_one_arg_first(self): - kwargs = {'a': 1, 'b': None} + kwargs = {"a": 1, "b": None} expected_kwargs = kwargs.copy() - res = getargs('a', kwargs) + res = getargs("a", kwargs) self.assertEqual(res, 1) self.assertDictEqual(kwargs, expected_kwargs) def test_one_arg_second(self): - kwargs = {'a': 1, 'b': None} + kwargs = {"a": 1, "b": None} expected_kwargs = kwargs.copy() - res = getargs('b', kwargs) + res = getargs("b", kwargs) self.assertEqual(res, None) self.assertDictEqual(kwargs, expected_kwargs) def test_many_args_get_some(self): - kwargs = {'a': 1, 'b': None, 'c': 3} + kwargs = {"a": 1, "b": None, "c": 3} expected_kwargs = kwargs.copy() - res = getargs('a', 'c', kwargs) + res = getargs("a", "c", kwargs) self.assertListEqual(res, [1, 3]) 
self.assertDictEqual(kwargs, expected_kwargs) def test_many_args_get_all(self): - kwargs = {'a': 1, 'b': None, 'c': 3} + kwargs = {"a": 1, "b": None, "c": 3} expected_kwargs = kwargs.copy() - res = getargs('a', 'b', 'c', kwargs) + res = getargs("a", "b", "c", kwargs) self.assertListEqual(res, [1, None, 3]) self.assertDictEqual(kwargs, expected_kwargs) def test_many_args_reverse(self): - kwargs = {'a': 1, 'b': None, 'c': 3} + kwargs = {"a": 1, "b": None, "c": 3} expected_kwargs = kwargs.copy() - res = getargs('c', 'b', 'a', kwargs) + res = getargs("c", "b", "a", kwargs) self.assertListEqual(res, [3, None, 1]) self.assertDictEqual(kwargs, expected_kwargs) def test_many_args_unpack(self): - kwargs = {'a': 1, 'b': None, 'c': 3} + kwargs = {"a": 1, "b": None, "c": 3} expected_kwargs = kwargs.copy() - res1, res2, res3 = getargs('a', 'b', 'c', kwargs) + res1, res2, res3 = getargs("a", "b", "c", kwargs) self.assertEqual(res1, 1) self.assertEqual(res2, None) self.assertEqual(res3, 3) self.assertDictEqual(kwargs, expected_kwargs) def test_too_few_args(self): - kwargs = {'a': 1, 'b': None} - msg = 'Must supply at least one key and a dict' + kwargs = {"a": 1, "b": None} + msg = "Must supply at least one key and a dict" with self.assertRaisesWith(ValueError, msg): getargs(kwargs) def test_last_arg_not_dict(self): - kwargs = {'a': 1, 'b': None} - msg = 'Last argument must be a dict' + kwargs = {"a": 1, "b": None} + msg = "Last argument must be a dict" with self.assertRaisesWith(ValueError, msg): - getargs(kwargs, 'a') + getargs(kwargs, "a") def test_arg_not_found_one_arg(self): - kwargs = {'a': 1, 'b': None} + kwargs = {"a": 1, "b": None} msg = "Argument not found in dict: 'c'" with self.assertRaisesWith(ValueError, msg): - getargs('c', kwargs) + getargs("c", kwargs) def test_arg_not_found_many_args(self): - kwargs = {'a': 1, 'b': None} + kwargs = {"a": 1, "b": None} msg = "Argument not found in dict: 'c'" with self.assertRaisesWith(ValueError, msg): - getargs('a', 'c', kwargs) + getargs("a", "c", kwargs) class TestPopargs(TestCase): """Test the popargs function and its error conditions.""" def test_one_arg_first(self): - kwargs = {'a': 1, 'b': None} - res = popargs('a', kwargs) + kwargs = {"a": 1, "b": None} + res = popargs("a", kwargs) self.assertEqual(res, 1) - self.assertDictEqual(kwargs, {'b': None}) + self.assertDictEqual(kwargs, {"b": None}) def test_one_arg_second(self): - kwargs = {'a': 1, 'b': None} - res = popargs('b', kwargs) + kwargs = {"a": 1, "b": None} + res = popargs("b", kwargs) self.assertEqual(res, None) - self.assertDictEqual(kwargs, {'a': 1}) + self.assertDictEqual(kwargs, {"a": 1}) def test_many_args_pop_some(self): - kwargs = {'a': 1, 'b': None, 'c': 3} - res = popargs('a', 'c', kwargs) + kwargs = {"a": 1, "b": None, "c": 3} + res = popargs("a", "c", kwargs) self.assertListEqual(res, [1, 3]) - self.assertDictEqual(kwargs, {'b': None}) + self.assertDictEqual(kwargs, {"b": None}) def test_many_args_pop_all(self): - kwargs = {'a': 1, 'b': None, 'c': 3} - res = popargs('a', 'b', 'c', kwargs) + kwargs = {"a": 1, "b": None, "c": 3} + res = popargs("a", "b", "c", kwargs) self.assertListEqual(res, [1, None, 3]) self.assertDictEqual(kwargs, {}) def test_many_args_reverse(self): - kwargs = {'a': 1, 'b': None, 'c': 3} - res = popargs('c', 'b', 'a', kwargs) + kwargs = {"a": 1, "b": None, "c": 3} + res = popargs("c", "b", "a", kwargs) self.assertListEqual(res, [3, None, 1]) self.assertDictEqual(kwargs, {}) def test_many_args_unpack(self): - kwargs = {'a': 1, 'b': None, 'c': 3} - res1, res2, 
res3 = popargs('a', 'b', 'c', kwargs) + kwargs = {"a": 1, "b": None, "c": 3} + res1, res2, res3 = popargs("a", "b", "c", kwargs) self.assertEqual(res1, 1) self.assertEqual(res2, None) self.assertEqual(res3, 3) self.assertDictEqual(kwargs, {}) def test_too_few_args(self): - kwargs = {'a': 1, 'b': None} - msg = 'Must supply at least one key and a dict' + kwargs = {"a": 1, "b": None} + msg = "Must supply at least one key and a dict" with self.assertRaisesWith(ValueError, msg): popargs(kwargs) def test_last_arg_not_dict(self): - kwargs = {'a': 1, 'b': None} - msg = 'Last argument must be a dict' + kwargs = {"a": 1, "b": None} + msg = "Last argument must be a dict" with self.assertRaisesWith(ValueError, msg): - popargs(kwargs, 'a') + popargs(kwargs, "a") def test_arg_not_found_one_arg(self): - kwargs = {'a': 1, 'b': None} + kwargs = {"a": 1, "b": None} msg = "Argument not found in dict: 'c'" with self.assertRaisesWith(ValueError, msg): - popargs('c', kwargs) + popargs("c", kwargs) def test_arg_not_found_many_args(self): - kwargs = {'a': 1, 'b': None} + kwargs = {"a": 1, "b": None} msg = "Argument not found in dict: 'c'" with self.assertRaisesWith(ValueError, msg): - popargs('a', 'c', kwargs) + popargs("a", "c", kwargs) class TestPopargsToDict(TestCase): """Test the popargs_to_dict function and its error conditions.""" def test_one_arg_first(self): - kwargs = {'a': 1, 'b': None} - res = popargs_to_dict(['a'], kwargs) - self.assertDictEqual(res, {'a': 1}) - self.assertDictEqual(kwargs, {'b': None}) + kwargs = {"a": 1, "b": None} + res = popargs_to_dict(["a"], kwargs) + self.assertDictEqual(res, {"a": 1}) + self.assertDictEqual(kwargs, {"b": None}) def test_one_arg_second(self): - kwargs = {'a': 1, 'b': None} - res = popargs_to_dict(['b'], kwargs) - self.assertDictEqual(res, {'b': None}) - self.assertDictEqual(kwargs, {'a': 1}) + kwargs = {"a": 1, "b": None} + res = popargs_to_dict(["b"], kwargs) + self.assertDictEqual(res, {"b": None}) + self.assertDictEqual(kwargs, {"a": 1}) def test_many_args_pop_some(self): - kwargs = {'a': 1, 'b': None, 'c': 3} - res = popargs_to_dict(['a', 'c'], kwargs) - self.assertDictEqual(res, {'a': 1, 'c': 3}) - self.assertDictEqual(kwargs, {'b': None}) + kwargs = {"a": 1, "b": None, "c": 3} + res = popargs_to_dict(["a", "c"], kwargs) + self.assertDictEqual(res, {"a": 1, "c": 3}) + self.assertDictEqual(kwargs, {"b": None}) def test_many_args_pop_all(self): - kwargs = {'a': 1, 'b': None, 'c': 3} - res = popargs_to_dict(['a', 'b', 'c'], kwargs) - self.assertDictEqual(res, {'a': 1, 'b': None, 'c': 3}) + kwargs = {"a": 1, "b": None, "c": 3} + res = popargs_to_dict(["a", "b", "c"], kwargs) + self.assertDictEqual(res, {"a": 1, "b": None, "c": 3}) self.assertDictEqual(kwargs, {}) def test_arg_not_found_one_arg(self): - kwargs = {'a': 1, 'b': None} + kwargs = {"a": 1, "b": None} msg = "Argument not found in dict: 'c'" with self.assertRaisesWith(ValueError, msg): - popargs_to_dict(['c'], kwargs) + popargs_to_dict(["c"], kwargs) class TestMacro(TestCase): - def test_macro(self): self.assertTrue(isinstance(get_docval_macro(), dict)) - self.assertSetEqual(set(get_docval_macro().keys()), {'array_data', 'scalar_data', 'data'}) + self.assertSetEqual( + set(get_docval_macro().keys()), + {"array_data", "scalar_data", "data"}, + ) - self.assertTupleEqual(get_docval_macro('scalar_data'), (str, int, float, bytes, bool)) + self.assertTupleEqual(get_docval_macro("scalar_data"), (str, int, float, bytes, bool)) - @docval_macro('scalar_data') + @docval_macro("scalar_data") class Dummy1: pass - 
self.assertTupleEqual(get_docval_macro('scalar_data'), (str, int, float, bytes, bool, Dummy1)) + self.assertTupleEqual( + get_docval_macro("scalar_data"), + (str, int, float, bytes, bool, Dummy1), + ) - @docval_macro('dummy') + @docval_macro("dummy") class Dummy2: pass - self.assertTupleEqual(get_docval_macro('dummy'), (Dummy2, )) + self.assertTupleEqual(get_docval_macro("dummy"), (Dummy2,)) diff --git a/tests/unit/utils_test/test_labelleddict.py b/tests/unit/utils_test/test_labelleddict.py index 325975985..d3cc21d99 100644 --- a/tests/unit/utils_test/test_labelleddict.py +++ b/tests/unit/utils_test/test_labelleddict.py @@ -3,7 +3,6 @@ class MyTestClass: - def __init__(self, prop1, prop2): self._prop1 = prop1 self._prop2 = prop2 @@ -18,110 +17,109 @@ def prop2(self): class TestLabelledDict(TestCase): - def test_constructor(self): """Test that constructor sets arguments properly.""" - ld = LabelledDict(label='all_objects', key_attr='prop1') - self.assertEqual(ld.label, 'all_objects') - self.assertEqual(ld.key_attr, 'prop1') + ld = LabelledDict(label="all_objects", key_attr="prop1") + self.assertEqual(ld.label, "all_objects") + self.assertEqual(ld.key_attr, "prop1") def test_constructor_default(self): """Test that constructor sets default key attribute.""" - ld = LabelledDict(label='all_objects') - self.assertEqual(ld.key_attr, 'name') + ld = LabelledDict(label="all_objects") + self.assertEqual(ld.key_attr, "name") def test_set_key_attr(self): """Test that the key attribute cannot be set after initialization.""" - ld = LabelledDict(label='all_objects') + ld = LabelledDict(label="all_objects") with self.assertRaises(AttributeError): - ld.key_attr = 'another_name' + ld.key_attr = "another_name" def test_getitem_unknown_val(self): """Test that dict[unknown_key] where the key unknown_key is not in the dict raises an error.""" - ld = LabelledDict(label='all_objects', key_attr='prop1') + ld = LabelledDict(label="all_objects", key_attr="prop1") with self.assertRaises(KeyError): - ld['unknown_key'] + ld["unknown_key"] def test_getitem_eqeq_unknown_val(self): """Test that dict[unknown_attr == val] where there are no query matches returns an empty set.""" - ld = LabelledDict(label='all_objects', key_attr='prop1') - self.assertSetEqual(ld['unknown_attr == val'], set()) + ld = LabelledDict(label="all_objects", key_attr="prop1") + self.assertSetEqual(ld["unknown_attr == val"], set()) def test_getitem_eqeq_other_key(self): """Test that dict[other_attr == val] where there are no query matches returns an empty set.""" - ld = LabelledDict(label='all_objects', key_attr='prop1') - self.assertSetEqual(ld['prop2 == val'], set()) + ld = LabelledDict(label="all_objects", key_attr="prop1") + self.assertSetEqual(ld["prop2 == val"], set()) def test_getitem_eqeq_no_key_attr(self): """Test that dict[key_attr == val] raises an error if key_attr is not given.""" - ld = LabelledDict(label='all_objects', key_attr='prop1') + ld = LabelledDict(label="all_objects", key_attr="prop1") with self.assertRaisesWith(ValueError, "An attribute name is required before '=='."): - ld[' == unknown_key'] + ld[" == unknown_key"] def test_getitem_eqeq_no_val(self): """Test that dict[key_attr == val] raises an error if val is not given.""" - ld = LabelledDict(label='all_objects', key_attr='prop1') + ld = LabelledDict(label="all_objects", key_attr="prop1") with self.assertRaisesWith(ValueError, "A value is required after '=='."): - ld['prop1 == '] + ld["prop1 == "] def test_getitem_eqeq_no_key_attr_no_val(self): """Test that dict[key_attr == 
val] raises an error if key_attr is not given and val is not given.""" - ld = LabelledDict(label='all_objects', key_attr='prop1') + ld = LabelledDict(label="all_objects", key_attr="prop1") with self.assertRaisesWith(ValueError, "An attribute name is required before '=='."): - ld[' == '] + ld[" == "] def test_add_basic(self): """Test add method on object with correct key_attr.""" - ld = LabelledDict(label='all_objects', key_attr='prop1') - obj1 = MyTestClass('a', 'b') + ld = LabelledDict(label="all_objects", key_attr="prop1") + obj1 = MyTestClass("a", "b") ld.add(obj1) - self.assertIs(ld['a'], obj1) + self.assertIs(ld["a"], obj1) def test_add_value_missing_key(self): """Test that add raises an error if the value being set does not have the attribute key_attr.""" - ld = LabelledDict(label='all_objects', key_attr='unknown_key') - obj1 = MyTestClass('a', 'b') + ld = LabelledDict(label="all_objects", key_attr="unknown_key") + obj1 = MyTestClass("a", "b") - err_msg = r"Cannot set value '<.*>' in LabelledDict\. Value must have attribute 'unknown_key'\." + err_msg = r"Cannot set value '<.*>' in LabelledDict\. Value must have attribute" r" 'unknown_key'\." with self.assertRaisesRegex(ValueError, err_msg): ld.add(obj1) def test_setitem_getitem_basic(self): """Test that setitem and getitem properly set and get the object.""" - ld = LabelledDict(label='all_objects', key_attr='prop1') - obj1 = MyTestClass('a', 'b') + ld = LabelledDict(label="all_objects", key_attr="prop1") + obj1 = MyTestClass("a", "b") ld.add(obj1) - self.assertIs(ld['a'], obj1) + self.assertIs(ld["a"], obj1) def test_setitem_value_missing_key(self): """Test that setitem raises an error if the value being set does not have the attribute key_attr.""" - ld = LabelledDict(label='all_objects', key_attr='unknown_key') - obj1 = MyTestClass('a', 'b') + ld = LabelledDict(label="all_objects", key_attr="unknown_key") + obj1 = MyTestClass("a", "b") - err_msg = r"Cannot set value '<.*>' in LabelledDict\. Value must have attribute 'unknown_key'\." + err_msg = r"Cannot set value '<.*>' in LabelledDict\. Value must have attribute" r" 'unknown_key'\." with self.assertRaisesRegex(ValueError, err_msg): - ld['a'] = obj1 + ld["a"] = obj1 def test_setitem_value_inconsistent_key(self): """Test that setitem raises an error if the value being set has an inconsistent key.""" - ld = LabelledDict(label='all_objects', key_attr='prop1') - obj1 = MyTestClass('a', 'b') + ld = LabelledDict(label="all_objects", key_attr="prop1") + obj1 = MyTestClass("a", "b") err_msg = r"Key 'b' must equal attribute 'prop1' of '<.*>'\." with self.assertRaisesRegex(KeyError, err_msg): - ld['b'] = obj1 + ld["b"] = obj1 def test_setitem_value_duplicate_key(self): """Test that setitem raises an error if the key already exists in the dict.""" - ld = LabelledDict(label='all_objects', key_attr='prop1') - obj1 = MyTestClass('a', 'b') - obj2 = MyTestClass('a', 'c') + ld = LabelledDict(label="all_objects", key_attr="prop1") + obj1 = MyTestClass("a", "b") + obj2 = MyTestClass("a", "c") - ld['a'] = obj1 + ld["a"] = obj1 err_msg = "Key 'a' is already in this dict. Cannot reset items in a LabelledDict." 
with self.assertRaisesWith(TypeError, err_msg): - ld['a'] = obj2 + ld["a"] = obj2 def test_add_callable(self): """Test that add properly adds the object and calls the add_callable function.""" @@ -130,10 +128,10 @@ def test_add_callable(self): def func(v): self.signal = v - ld = LabelledDict(label='all_objects', key_attr='prop1', add_callable=func) - obj1 = MyTestClass('a', 'b') + ld = LabelledDict(label="all_objects", key_attr="prop1", add_callable=func) + obj1 = MyTestClass("a", "b") ld.add(obj1) - self.assertIs(ld['a'], obj1) + self.assertIs(ld["a"], obj1) self.assertIs(self.signal, obj1) def test_setitem_callable(self): @@ -143,51 +141,51 @@ def test_setitem_callable(self): def func(v): self.signal = v - ld = LabelledDict(label='all_objects', key_attr='prop1', add_callable=func) - obj1 = MyTestClass('a', 'b') - ld['a'] = obj1 - self.assertIs(ld['a'], obj1) + ld = LabelledDict(label="all_objects", key_attr="prop1", add_callable=func) + obj1 = MyTestClass("a", "b") + ld["a"] = obj1 + self.assertIs(ld["a"], obj1) self.assertIs(self.signal, obj1) def test_getitem_eqeq_nonempty(self): """Test that dict[key_attr == val] returns the single matching object.""" - ld = LabelledDict(label='all_objects', key_attr='prop1') - obj1 = MyTestClass('a', 'b') + ld = LabelledDict(label="all_objects", key_attr="prop1") + obj1 = MyTestClass("a", "b") ld.add(obj1) - self.assertIs(ld['prop1 == a'], obj1) + self.assertIs(ld["prop1 == a"], obj1) def test_getitem_eqeq_nonempty_key_attr_no_match(self): """Test that dict[key_attr == unknown_val] where a matching value is not found raises a KeyError.""" - ld = LabelledDict(label='all_objects', key_attr='prop1') - obj1 = MyTestClass('a', 'b') + ld = LabelledDict(label="all_objects", key_attr="prop1") + obj1 = MyTestClass("a", "b") ld.add(obj1) with self.assertRaises(KeyError): - ld['prop1 == unknown_val'] # same as ld['unknown_val'] + ld["prop1 == unknown_val"] # same as ld['unknown_val'] def test_getitem_eqeq_nonempty_unknown_attr(self): """Test that dict[unknown_attr == val] where unknown_attr is not a field on the values raises an error.""" - ld = LabelledDict(label='all_objects', key_attr='prop1') - obj1 = MyTestClass('a', 'b') - ld['a'] = obj1 - self.assertSetEqual(ld['unknown_attr == unknown_val'], set()) + ld = LabelledDict(label="all_objects", key_attr="prop1") + obj1 = MyTestClass("a", "b") + ld["a"] = obj1 + self.assertSetEqual(ld["unknown_attr == unknown_val"], set()) def test_getitem_nonempty_other_key(self): """Test that dict[other_key == val] returns a set of matching objects.""" - ld = LabelledDict(label='all_objects', key_attr='prop1') - obj1 = MyTestClass('a', 'b') - obj2 = MyTestClass('d', 'b') - obj3 = MyTestClass('f', 'e') + ld = LabelledDict(label="all_objects", key_attr="prop1") + obj1 = MyTestClass("a", "b") + obj2 = MyTestClass("d", "b") + obj3 = MyTestClass("f", "e") ld.add(obj1) ld.add(obj2) ld.add(obj3) - self.assertSetEqual(ld['prop2 == b'], {obj1, obj2}) + self.assertSetEqual(ld["prop2 == b"], {obj1, obj2}) def test_pop_nocallback(self): - ld = LabelledDict(label='all_objects', key_attr='prop1') - obj1 = MyTestClass('a', 'b') + ld = LabelledDict(label="all_objects", key_attr="prop1") + obj1 = MyTestClass("a", "b") ld.add(obj1) - ret = ld.pop('a') + ret = ld.pop("a") self.assertEqual(ret, obj1) self.assertEqual(ld, dict()) @@ -197,22 +195,22 @@ def test_pop_callback(self): def func(v): self.signal = v - ld = LabelledDict(label='all_objects', key_attr='prop1', remove_callable=func) - obj1 = MyTestClass('a', 'b') + ld = 
LabelledDict(label="all_objects", key_attr="prop1", remove_callable=func) + obj1 = MyTestClass("a", "b") ld.add(obj1) - ret = ld.pop('a') + ret = ld.pop("a") self.assertEqual(ret, obj1) self.assertEqual(self.signal, obj1) self.assertEqual(ld, dict()) def test_popitem_nocallback(self): - ld = LabelledDict(label='all_objects', key_attr='prop1') - obj1 = MyTestClass('a', 'b') + ld = LabelledDict(label="all_objects", key_attr="prop1") + obj1 = MyTestClass("a", "b") ld.add(obj1) ret = ld.popitem() - self.assertEqual(ret, ('a', obj1)) + self.assertEqual(ret, ("a", obj1)) self.assertEqual(ld, dict()) def test_popitem_callback(self): @@ -221,19 +219,19 @@ def test_popitem_callback(self): def func(v): self.signal = v - ld = LabelledDict(label='all_objects', key_attr='prop1', remove_callable=func) - obj1 = MyTestClass('a', 'b') + ld = LabelledDict(label="all_objects", key_attr="prop1", remove_callable=func) + obj1 = MyTestClass("a", "b") ld.add(obj1) ret = ld.popitem() - self.assertEqual(ret, ('a', obj1)) + self.assertEqual(ret, ("a", obj1)) self.assertEqual(self.signal, obj1) self.assertEqual(ld, dict()) def test_clear_nocallback(self): - ld = LabelledDict(label='all_objects', key_attr='prop1') - obj1 = MyTestClass('a', 'b') - obj2 = MyTestClass('d', 'b') + ld = LabelledDict(label="all_objects", key_attr="prop1") + obj1 = MyTestClass("a", "b") + obj2 = MyTestClass("d", "b") ld.add(obj1) ld.add(obj2) ld.clear() @@ -245,9 +243,9 @@ def test_clear_callback(self): def func(v): self.signal.add(v) - ld = LabelledDict(label='all_objects', key_attr='prop1', remove_callable=func) - obj1 = MyTestClass('a', 'b') - obj2 = MyTestClass('d', 'b') + ld = LabelledDict(label="all_objects", key_attr="prop1", remove_callable=func) + obj1 = MyTestClass("a", "b") + obj2 = MyTestClass("d", "b") ld.add(obj1) ld.add(obj2) ld.clear() @@ -255,11 +253,11 @@ def func(v): self.assertEqual(ld, dict()) def test_delitem_nocallback(self): - ld = LabelledDict(label='all_objects', key_attr='prop1') - obj1 = MyTestClass('a', 'b') + ld = LabelledDict(label="all_objects", key_attr="prop1") + obj1 = MyTestClass("a", "b") ld.add(obj1) - del ld['a'] + del ld["a"] self.assertEqual(ld, dict()) def test_delitem_callback(self): @@ -268,22 +266,22 @@ def test_delitem_callback(self): def func(v): self.signal = v - ld = LabelledDict(label='all_objects', key_attr='prop1', remove_callable=func) - obj1 = MyTestClass('a', 'b') + ld = LabelledDict(label="all_objects", key_attr="prop1", remove_callable=func) + obj1 = MyTestClass("a", "b") ld.add(obj1) - del ld['a'] + del ld["a"] self.assertEqual(self.signal, obj1) self.assertEqual(ld, dict()) def test_update_callback(self): - ld = LabelledDict(label='all_objects', key_attr='prop1') + ld = LabelledDict(label="all_objects", key_attr="prop1") with self.assertRaisesWith(TypeError, "update is not supported for LabelledDict"): ld.update(object()) def test_setdefault_callback(self): - ld = LabelledDict(label='all_objects', key_attr='prop1') + ld = LabelledDict(label="all_objects", key_attr="prop1") with self.assertRaisesWith(TypeError, "setdefault is not supported for LabelledDict"): ld.setdefault(object()) diff --git a/tests/unit/utils_test/test_utils.py b/tests/unit/utils_test/test_utils.py index b0cd05e9d..c4924df08 100644 --- a/tests/unit/utils_test/test_utils.py +++ b/tests/unit/utils_test/test_utils.py @@ -2,27 +2,27 @@ import h5py import numpy as np + from hdmf.data_utils import DataChunkIterator, DataIO from hdmf.testing import TestCase from hdmf.utils import get_data_shape, to_uint_array class 
TestGetDataShape(TestCase): - def test_h5dataset(self): """Test get_data_shape on h5py.Datasets of various shapes and maxshape.""" - path = 'test_get_data_shape.h5' - with h5py.File(path, 'w') as f: - dset = f.create_dataset('data', data=((1, 2), (3, 4), (5, 6))) + path = "test_get_data_shape.h5" + with h5py.File(path, "w") as f: + dset = f.create_dataset("data", data=((1, 2), (3, 4), (5, 6))) res = get_data_shape(dset) self.assertTupleEqual(res, (3, 2)) - dset = f.create_dataset('shape', shape=(3, 2)) + dset = f.create_dataset("shape", shape=(3, 2)) res = get_data_shape(dset) self.assertTupleEqual(res, (3, 2)) # test that maxshape takes priority - dset = f.create_dataset('shape_maxshape', shape=(3, 2), maxshape=(None, 100)) + dset = f.create_dataset("shape_maxshape", shape=(3, 2), maxshape=(None, 100)) res = get_data_shape(dset) self.assertTupleEqual(res, (None, 100)) @@ -36,7 +36,7 @@ def test_dci(self): dci = DataChunkIterator(data=[1, 2]) res = get_data_shape(dci) - self.assertTupleEqual(res, (2, )) + self.assertTupleEqual(res, (2,)) dci = DataChunkIterator(data=[[1, 2], [3, 4], [5, 6]]) res = get_data_shape(dci) @@ -51,7 +51,7 @@ def test_dataio(self): """Test get_data_shape on DataIO of various shapes and maxshape.""" dio = DataIO(data=[1, 2]) res = get_data_shape(dio) - self.assertTupleEqual(res, (2, )) + self.assertTupleEqual(res, (2,)) dio = DataIO(data=[[1, 2], [3, 4], [5, 6]]) res = get_data_shape(dio) @@ -64,10 +64,10 @@ def test_dataio(self): def test_list(self): """Test get_data_shape on lists of various shapes.""" res = get_data_shape(list()) - self.assertTupleEqual(res, (0, )) + self.assertTupleEqual(res, (0,)) res = get_data_shape([1, 2]) - self.assertTupleEqual(res, (2, )) + self.assertTupleEqual(res, (2,)) res = get_data_shape([[1, 2], [3, 4], [5, 6]]) self.assertTupleEqual(res, (3, 2)) @@ -75,10 +75,10 @@ def test_list(self): def test_tuple(self): """Test get_data_shape on tuples of various shapes.""" res = get_data_shape(tuple()) - self.assertTupleEqual(res, (0, )) + self.assertTupleEqual(res, (0,)) res = get_data_shape((1, 2)) - self.assertTupleEqual(res, (2, )) + self.assertTupleEqual(res, (2,)) res = get_data_shape(((1, 2), (3, 4), (5, 6))) self.assertTupleEqual(res, (3, 2)) @@ -89,10 +89,10 @@ def test_nparray(self): self.assertTupleEqual(res, tuple()) res = get_data_shape(np.array([])) - self.assertTupleEqual(res, (0, )) + self.assertTupleEqual(res, (0,)) res = get_data_shape(np.array([1, 2])) - self.assertTupleEqual(res, (2, )) + self.assertTupleEqual(res, (2,)) res = get_data_shape(np.array([[1, 2], [3, 4], [5, 6]])) self.assertTupleEqual(res, (3, 2)) @@ -106,32 +106,32 @@ def test_other(self): self.assertIsNone(res) res = get_data_shape([None, None]) - self.assertTupleEqual(res, (2, )) + self.assertTupleEqual(res, (2,)) res = get_data_shape(object()) self.assertIsNone(res) res = get_data_shape([object(), object()]) - self.assertTupleEqual(res, (2, )) + self.assertTupleEqual(res, (2,)) def test_string(self): """Test get_data_shape on strings and collections of strings.""" - res = get_data_shape('abc') + res = get_data_shape("abc") self.assertIsNone(res) - res = get_data_shape(('a', 'b')) - self.assertTupleEqual(res, (2, )) + res = get_data_shape(("a", "b")) + self.assertTupleEqual(res, (2,)) - res = get_data_shape((('a', 'b'), ('c', 'd'), ('e', 'f'))) + res = get_data_shape((("a", "b"), ("c", "d"), ("e", "f"))) self.assertTupleEqual(res, (3, 2)) def test_set(self): """Test get_data_shape on sets, which have __len__ but are not subscriptable.""" res = 
get_data_shape(set()) - self.assertTupleEqual(res, (0, )) + self.assertTupleEqual(res, (0,)) res = get_data_shape({1, 2}) - self.assertTupleEqual(res, (2, )) + self.assertTupleEqual(res, (2,)) def test_arbitrary_iterable_with_len(self): """Test get_data_shape with strict_no_data_load=True on an arbitrary iterable object with __len__.""" @@ -168,7 +168,6 @@ def test_strict_no_data_load(self): class TestToUintArray(TestCase): - def test_ndarray_uint(self): arr = np.array([0, 1, 2], dtype=np.uint32) res = to_uint_array(arr) @@ -181,12 +180,12 @@ def test_ndarray_int(self): def test_ndarray_int_neg(self): arr = np.array([0, -1, 2], dtype=np.int32) - with self.assertRaisesWith(ValueError, 'Cannot convert negative integer values to uint.'): + with self.assertRaisesWith(ValueError, "Cannot convert negative integer values to uint."): to_uint_array(arr) def test_ndarray_float(self): arr = np.array([0, 1, 2], dtype=np.float64) - with self.assertRaisesWith(ValueError, 'Cannot convert array of dtype float64 to uint.'): + with self.assertRaisesWith(ValueError, "Cannot convert array of dtype float64 to uint."): to_uint_array(arr) def test_list_int(self): @@ -197,10 +196,10 @@ def test_list_int(self): def test_list_int_neg(self): arr = [0, -1, 2] - with self.assertRaisesWith(ValueError, 'Cannot convert negative integer values to uint.'): + with self.assertRaisesWith(ValueError, "Cannot convert negative integer values to uint."): to_uint_array(arr) def test_list_float(self): - arr = [0., 1., 2.] - with self.assertRaisesWith(ValueError, 'Cannot convert array of dtype float64 to uint.'): + arr = [0.0, 1.0, 2.0] + with self.assertRaisesWith(ValueError, "Cannot convert array of dtype float64 to uint."): to_uint_array(arr) diff --git a/tests/unit/validator_tests/test_errors.py b/tests/unit/validator_tests/test_errors.py index 3ae6aea8f..12c67d426 100644 --- a/tests/unit/validator_tests/test_errors.py +++ b/tests/unit/validator_tests/test_errors.py @@ -6,49 +6,49 @@ class TestErrorEquality(TestCase): def test_self_equality(self): """Verify that one error equals itself""" - error = Error('foo', 'bad thing', 'a.b.c') + error = Error("foo", "bad thing", "a.b.c") self.assertEqual(error, error) def test_equality_with_same_field_values(self): """Verify that two errors with the same field values are equal""" - err1 = Error('foo', 'bad thing', 'a.b.c') - err2 = Error('foo', 'bad thing', 'a.b.c') + err1 = Error("foo", "bad thing", "a.b.c") + err2 = Error("foo", "bad thing", "a.b.c") self.assertEqual(err1, err2) def test_not_equal_with_different_reason(self): """Verify that two errors with a different reason are not equal""" - err1 = Error('foo', 'bad thing', 'a.b.c') - err2 = Error('foo', 'something else', 'a.b.c') + err1 = Error("foo", "bad thing", "a.b.c") + err2 = Error("foo", "something else", "a.b.c") self.assertNotEqual(err1, err2) def test_not_equal_with_different_name(self): """Verify that two errors with a different name are not equal""" - err1 = Error('foo', 'bad thing', 'a.b.c') - err2 = Error('bar', 'bad thing', 'a.b.c') + err1 = Error("foo", "bad thing", "a.b.c") + err2 = Error("bar", "bad thing", "a.b.c") self.assertNotEqual(err1, err2) def test_not_equal_with_different_location(self): """Verify that two errors with a different location are not equal""" - err1 = Error('foo', 'bad thing', 'a.b.c') - err2 = Error('foo', 'bad thing', 'd.e.f') + err1 = Error("foo", "bad thing", "a.b.c") + err2 = Error("foo", "bad thing", "d.e.f") self.assertNotEqual(err1, err2) def test_equal_with_no_location(self): """Verify 
that two errors with no location but the same name are equal""" - err1 = Error('foo', 'bad thing') - err2 = Error('foo', 'bad thing') + err1 = Error("foo", "bad thing") + err2 = Error("foo", "bad thing") self.assertEqual(err1, err2) def test_not_equal_with_overlapping_name_when_no_location(self): """Verify that two errors with an overlapping name but no location are not equal """ - err1 = Error('foo', 'bad thing') - err2 = Error('x/y/foo', 'bad thing') + err1 = Error("foo", "bad thing") + err2 = Error("x/y/foo", "bad thing") self.assertNotEqual(err1, err2) def test_equal_with_overlapping_name_when_location_present(self): """Verify that two errors with an overlapping name and a location are equal""" - err1 = Error('foo', 'bad thing', 'a.b.c') - err2 = Error('x/y/foo', 'bad thing', 'a.b.c') + err1 = Error("foo", "bad thing", "a.b.c") + err2 = Error("x/y/foo", "bad thing", "a.b.c") self.assertEqual(err1, err2) diff --git a/tests/unit/validator_tests/test_validate.py b/tests/unit/validator_tests/test_validate.py index 506f9edac..27a752031 100644 --- a/tests/unit/validator_tests/test_validate.py +++ b/tests/unit/validator_tests/test_validate.py @@ -4,27 +4,54 @@ import numpy as np from dateutil.tz import tzlocal -from hdmf.build import GroupBuilder, DatasetBuilder, LinkBuilder, ReferenceBuilder, TypeMap, BuildManager -from hdmf.spec import (GroupSpec, AttributeSpec, DatasetSpec, SpecCatalog, SpecNamespace, - LinkSpec, RefSpec, NamespaceCatalog, DtypeSpec) + +from hdmf.backends.hdf5 import HDF5IO +from hdmf.build import ( + BuildManager, + DatasetBuilder, + GroupBuilder, + LinkBuilder, + ReferenceBuilder, + TypeMap, +) +from hdmf.spec import ( + AttributeSpec, + DatasetSpec, + DtypeSpec, + GroupSpec, + LinkSpec, + NamespaceCatalog, + RefSpec, + SpecCatalog, + SpecNamespace, +) from hdmf.spec.spec import ONE_OR_MANY, ZERO_OR_MANY, ZERO_OR_ONE from hdmf.testing import TestCase, remove_test_file from hdmf.validate import ValidatorMap -from hdmf.validate.errors import (DtypeError, MissingError, ExpectedArrayError, MissingDataType, - IncorrectQuantityError, IllegalLinkError) -from hdmf.backends.hdf5 import HDF5IO +from hdmf.validate.errors import ( + DtypeError, + ExpectedArrayError, + IllegalLinkError, + IncorrectQuantityError, + MissingDataType, + MissingError, +) -CORE_NAMESPACE = 'test_core' +CORE_NAMESPACE = "test_core" class ValidatorTestBase(TestCase, metaclass=ABCMeta): - def setUp(self): spec_catalog = SpecCatalog() for spec in self.getSpecs(): - spec_catalog.register_spec(spec, 'test.yaml') + spec_catalog.register_spec(spec, "test.yaml") self.namespace = SpecNamespace( - 'a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], version='0.1.0', catalog=spec_catalog) + "a test namespace", + CORE_NAMESPACE, + [{"source": "test.yaml"}], + version="0.1.0", + catalog=spec_catalog, + ) self.vmap = ValidatorMap(self.namespace) @abstractmethod @@ -41,230 +68,320 @@ def assertValidationError(self, error, type_, name=None, reason=None): class TestEmptySpec(ValidatorTestBase): - def getSpecs(self): - return (GroupSpec('A test group specification with a data type', data_type_def='Bar'),) + return ( + GroupSpec( + "A test group specification with a data type", + data_type_def="Bar", + ), + ) def test_valid(self): - builder = GroupBuilder('my_bar', attributes={'data_type': 'Bar'}) - validator = self.vmap.get_validator('Bar') + builder = GroupBuilder("my_bar", attributes={"data_type": "Bar"}) + validator = self.vmap.get_validator("Bar") result = validator.validate(builder) 
self.assertEqual(len(result), 0) def test_invalid_missing_req_type(self): - builder = GroupBuilder('my_bar') + builder = GroupBuilder("my_bar") err_msg = r"builder must have data type defined with attribute '[A-Za-z_]+'" with self.assertRaisesRegex(ValueError, err_msg): self.vmap.validate(builder) class TestBasicSpec(ValidatorTestBase): - def getSpecs(self): - ret = GroupSpec('A test group specification with a data type', - data_type_def='Bar', - datasets=[DatasetSpec('an example dataset', 'int', name='data', - attributes=[AttributeSpec( - 'attr2', 'an example integer attribute', 'int')])], - attributes=[AttributeSpec('attr1', 'an example string attribute', 'text')]) + ret = GroupSpec( + "A test group specification with a data type", + data_type_def="Bar", + datasets=[ + DatasetSpec( + "an example dataset", + "int", + name="data", + attributes=[AttributeSpec("attr2", "an example integer attribute", "int")], + ) + ], + attributes=[AttributeSpec("attr1", "an example string attribute", "text")], + ) return (ret,) def test_invalid_missing(self): - builder = GroupBuilder('my_bar', attributes={'data_type': 'Bar'}) - validator = self.vmap.get_validator('Bar') + builder = GroupBuilder("my_bar", attributes={"data_type": "Bar"}) + validator = self.vmap.get_validator("Bar") result = validator.validate(builder) self.assertEqual(len(result), 2) - self.assertValidationError(result[0], MissingError, name='Bar/attr1') - self.assertValidationError(result[1], MissingError, name='Bar/data') + self.assertValidationError(result[0], MissingError, name="Bar/attr1") + self.assertValidationError(result[1], MissingError, name="Bar/data") def test_invalid_incorrect_type_get_validator(self): - builder = GroupBuilder('my_bar', attributes={'data_type': 'Bar', 'attr1': 10}) - validator = self.vmap.get_validator('Bar') + builder = GroupBuilder("my_bar", attributes={"data_type": "Bar", "attr1": 10}) + validator = self.vmap.get_validator("Bar") result = validator.validate(builder) self.assertEqual(len(result), 2) - self.assertValidationError(result[0], DtypeError, name='Bar/attr1') - self.assertValidationError(result[1], MissingError, name='Bar/data') + self.assertValidationError(result[0], DtypeError, name="Bar/attr1") + self.assertValidationError(result[1], MissingError, name="Bar/data") def test_invalid_incorrect_type_validate(self): - builder = GroupBuilder('my_bar', attributes={'data_type': 'Bar', 'attr1': 10}) + builder = GroupBuilder("my_bar", attributes={"data_type": "Bar", "attr1": 10}) result = self.vmap.validate(builder) self.assertEqual(len(result), 2) - self.assertValidationError(result[0], DtypeError, name='Bar/attr1') - self.assertValidationError(result[1], MissingError, name='Bar/data') + self.assertValidationError(result[0], DtypeError, name="Bar/attr1") + self.assertValidationError(result[1], MissingError, name="Bar/data") def test_valid(self): - builder = GroupBuilder('my_bar', - attributes={'data_type': 'Bar', 'attr1': 'a string attribute'}, - datasets=[DatasetBuilder('data', 100, attributes={'attr2': 10})]) - validator = self.vmap.get_validator('Bar') + builder = GroupBuilder( + "my_bar", + attributes={"data_type": "Bar", "attr1": "a string attribute"}, + datasets=[DatasetBuilder("data", 100, attributes={"attr2": 10})], + ) + validator = self.vmap.get_validator("Bar") result = validator.validate(builder) self.assertEqual(len(result), 0) class TestDateTimeInSpec(ValidatorTestBase): - def getSpecs(self): - ret = GroupSpec('A test group specification with a data type', - data_type_def='Bar', - 
datasets=[DatasetSpec('an example dataset', 'int', name='data', - attributes=[AttributeSpec( - 'attr2', 'an example integer attribute', 'int')]), - DatasetSpec('an example time dataset', 'isodatetime', name='time'), - DatasetSpec('an array of times', 'isodatetime', name='time_array', - dims=('num_times',), shape=(None,))], - attributes=[AttributeSpec('attr1', 'an example string attribute', 'text')]) + ret = GroupSpec( + "A test group specification with a data type", + data_type_def="Bar", + datasets=[ + DatasetSpec( + "an example dataset", + "int", + name="data", + attributes=[AttributeSpec("attr2", "an example integer attribute", "int")], + ), + DatasetSpec("an example time dataset", "isodatetime", name="time"), + DatasetSpec( + "an array of times", + "isodatetime", + name="time_array", + dims=("num_times",), + shape=(None,), + ), + ], + attributes=[AttributeSpec("attr1", "an example string attribute", "text")], + ) return (ret,) def test_valid_isodatetime(self): - builder = GroupBuilder('my_bar', - attributes={'data_type': 'Bar', 'attr1': 'a string attribute'}, - datasets=[DatasetBuilder('data', 100, attributes={'attr2': 10}), - DatasetBuilder('time', - datetime(2017, 5, 1, 12, 0, 0, tzinfo=tzlocal())), - DatasetBuilder('time_array', - [datetime(2017, 5, 1, 12, 0, 0, tzinfo=tzlocal())])]) - validator = self.vmap.get_validator('Bar') + builder = GroupBuilder( + "my_bar", + attributes={"data_type": "Bar", "attr1": "a string attribute"}, + datasets=[ + DatasetBuilder("data", 100, attributes={"attr2": 10}), + DatasetBuilder("time", datetime(2017, 5, 1, 12, 0, 0, tzinfo=tzlocal())), + DatasetBuilder( + "time_array", + [datetime(2017, 5, 1, 12, 0, 0, tzinfo=tzlocal())], + ), + ], + ) + validator = self.vmap.get_validator("Bar") result = validator.validate(builder) self.assertEqual(len(result), 0) def test_invalid_isodatetime(self): - builder = GroupBuilder('my_bar', - attributes={'data_type': 'Bar', 'attr1': 'a string attribute'}, - datasets=[DatasetBuilder('data', 100, attributes={'attr2': 10}), - DatasetBuilder('time', 100), - DatasetBuilder('time_array', - [datetime(2017, 5, 1, 12, 0, 0, tzinfo=tzlocal())])]) - validator = self.vmap.get_validator('Bar') + builder = GroupBuilder( + "my_bar", + attributes={"data_type": "Bar", "attr1": "a string attribute"}, + datasets=[ + DatasetBuilder("data", 100, attributes={"attr2": 10}), + DatasetBuilder("time", 100), + DatasetBuilder( + "time_array", + [datetime(2017, 5, 1, 12, 0, 0, tzinfo=tzlocal())], + ), + ], + ) + validator = self.vmap.get_validator("Bar") result = validator.validate(builder) self.assertEqual(len(result), 1) - self.assertValidationError(result[0], DtypeError, name='Bar/time') + self.assertValidationError(result[0], DtypeError, name="Bar/time") def test_invalid_isodatetime_array(self): - builder = GroupBuilder('my_bar', - attributes={'data_type': 'Bar', 'attr1': 'a string attribute'}, - datasets=[DatasetBuilder('data', 100, attributes={'attr2': 10}), - DatasetBuilder('time', - datetime(2017, 5, 1, 12, 0, 0, tzinfo=tzlocal())), - DatasetBuilder('time_array', - datetime(2017, 5, 1, 12, 0, 0, tzinfo=tzlocal()))]) - validator = self.vmap.get_validator('Bar') + builder = GroupBuilder( + "my_bar", + attributes={"data_type": "Bar", "attr1": "a string attribute"}, + datasets=[ + DatasetBuilder("data", 100, attributes={"attr2": 10}), + DatasetBuilder("time", datetime(2017, 5, 1, 12, 0, 0, tzinfo=tzlocal())), + DatasetBuilder( + "time_array", + datetime(2017, 5, 1, 12, 0, 0, tzinfo=tzlocal()), + ), + ], + ) + validator = 
self.vmap.get_validator("Bar") result = validator.validate(builder) self.assertEqual(len(result), 1) - self.assertValidationError(result[0], ExpectedArrayError, name='Bar/time_array') + self.assertValidationError(result[0], ExpectedArrayError, name="Bar/time_array") class TestNestedTypes(ValidatorTestBase): - def getSpecs(self): - baz = DatasetSpec('A dataset with a data type', 'int', data_type_def='Baz', - attributes=[AttributeSpec('attr2', 'an example integer attribute', 'int')]) - bar = GroupSpec('A test group specification with a data type', - data_type_def='Bar', - datasets=[DatasetSpec('an example dataset', data_type_inc='Baz')], - attributes=[AttributeSpec('attr1', 'an example string attribute', 'text')]) - foo = GroupSpec('A test group that contains a data type', - data_type_def='Foo', - groups=[GroupSpec('A Bar group for Foos', name='my_bar', data_type_inc='Bar')], - attributes=[AttributeSpec('foo_attr', 'a string attribute specified as text', 'text', - required=False)]) + baz = DatasetSpec( + "A dataset with a data type", + "int", + data_type_def="Baz", + attributes=[AttributeSpec("attr2", "an example integer attribute", "int")], + ) + bar = GroupSpec( + "A test group specification with a data type", + data_type_def="Bar", + datasets=[DatasetSpec("an example dataset", data_type_inc="Baz")], + attributes=[AttributeSpec("attr1", "an example string attribute", "text")], + ) + foo = GroupSpec( + "A test group that contains a data type", + data_type_def="Foo", + groups=[GroupSpec("A Bar group for Foos", name="my_bar", data_type_inc="Bar")], + attributes=[ + AttributeSpec( + "foo_attr", + "a string attribute specified as text", + "text", + required=False, + ) + ], + ) return (bar, foo, baz) def test_invalid_missing_named_req_group(self): """Test that a MissingDataType is returned when a required named nested data type is missing.""" - foo_builder = GroupBuilder('my_foo', attributes={'data_type': 'Foo', - 'foo_attr': 'example Foo object'}) + foo_builder = GroupBuilder( + "my_foo", + attributes={"data_type": "Foo", "foo_attr": "example Foo object"}, + ) results = self.vmap.validate(foo_builder) self.assertEqual(len(results), 1) - self.assertValidationError(results[0], MissingDataType, name='Foo', - reason='missing data type Bar (my_bar)') + self.assertValidationError( + results[0], + MissingDataType, + name="Foo", + reason="missing data type Bar (my_bar)", + ) def test_invalid_wrong_name_req_type(self): """Test that a MissingDataType is returned when a required nested data type is given the wrong name.""" - bar_builder = GroupBuilder('bad_bar_name', - attributes={'data_type': 'Bar', 'attr1': 'a string attribute'}, - datasets=[DatasetBuilder('data', 100, attributes={'attr2': 10})]) + bar_builder = GroupBuilder( + "bad_bar_name", + attributes={"data_type": "Bar", "attr1": "a string attribute"}, + datasets=[DatasetBuilder("data", 100, attributes={"attr2": 10})], + ) - foo_builder = GroupBuilder('my_foo', - attributes={'data_type': 'Foo', 'foo_attr': 'example Foo object'}, - groups=[bar_builder]) + foo_builder = GroupBuilder( + "my_foo", + attributes={"data_type": "Foo", "foo_attr": "example Foo object"}, + groups=[bar_builder], + ) results = self.vmap.validate(foo_builder) self.assertEqual(len(results), 1) - self.assertValidationError(results[0], MissingDataType, name='Foo') - self.assertEqual(results[0].data_type, 'Bar') + self.assertValidationError(results[0], MissingDataType, name="Foo") + self.assertEqual(results[0].data_type, "Bar") def test_invalid_missing_unnamed_req_group(self): 
"""Test that a MissingDataType is returned when a required unnamed nested data type is missing.""" - bar_builder = GroupBuilder('my_bar', - attributes={'data_type': 'Bar', 'attr1': 'a string attribute'}) + bar_builder = GroupBuilder( + "my_bar", + attributes={"data_type": "Bar", "attr1": "a string attribute"}, + ) - foo_builder = GroupBuilder('my_foo', - attributes={'data_type': 'Foo', 'foo_attr': 'example Foo object'}, - groups=[bar_builder]) + foo_builder = GroupBuilder( + "my_foo", + attributes={"data_type": "Foo", "foo_attr": "example Foo object"}, + groups=[bar_builder], + ) results = self.vmap.validate(foo_builder) self.assertEqual(len(results), 1) - self.assertValidationError(results[0], MissingDataType, name='Bar', - reason='missing data type Baz') + self.assertValidationError( + results[0], + MissingDataType, + name="Bar", + reason="missing data type Baz", + ) def test_valid(self): """Test that no errors are returned when nested data types are correctly built.""" - bar_builder = GroupBuilder('my_bar', - attributes={'data_type': 'Bar', 'attr1': 'a string attribute'}, - datasets=[DatasetBuilder('data', 100, attributes={'data_type': 'Baz', 'attr2': 10})]) + bar_builder = GroupBuilder( + "my_bar", + attributes={"data_type": "Bar", "attr1": "a string attribute"}, + datasets=[DatasetBuilder("data", 100, attributes={"data_type": "Baz", "attr2": 10})], + ) - foo_builder = GroupBuilder('my_foo', - attributes={'data_type': 'Foo', 'foo_attr': 'example Foo object'}, - groups=[bar_builder]) + foo_builder = GroupBuilder( + "my_foo", + attributes={"data_type": "Foo", "foo_attr": "example Foo object"}, + groups=[bar_builder], + ) results = self.vmap.validate(foo_builder) self.assertEqual(len(results), 0) def test_valid_wo_opt_attr(self): - """"Test that no errors are returned when an optional attribute is omitted from a group.""" - bar_builder = GroupBuilder('my_bar', - attributes={'data_type': 'Bar', 'attr1': 'a string attribute'}, - datasets=[DatasetBuilder('data', 100, attributes={'data_type': 'Baz', 'attr2': 10})]) - foo_builder = GroupBuilder('my_foo', - attributes={'data_type': 'Foo'}, - groups=[bar_builder]) + """ "Test that no errors are returned when an optional attribute is omitted from a group.""" + bar_builder = GroupBuilder( + "my_bar", + attributes={"data_type": "Bar", "attr1": "a string attribute"}, + datasets=[DatasetBuilder("data", 100, attributes={"data_type": "Baz", "attr2": 10})], + ) + foo_builder = GroupBuilder("my_foo", attributes={"data_type": "Foo"}, groups=[bar_builder]) results = self.vmap.validate(foo_builder) self.assertEqual(len(results), 0) class TestQuantityValidation(TestCase): - def create_test_specs(self, q_groups, q_datasets, q_links): - bar = GroupSpec('A test group', data_type_def='Bar') - baz = DatasetSpec('A test dataset', 'int', data_type_def='Baz') - qux = GroupSpec('A group to link', data_type_def='Qux') - foo = GroupSpec('A group containing a quantity of tests and datasets', - data_type_def='Foo', - groups=[GroupSpec('A bar', data_type_inc='Bar', quantity=q_groups)], - datasets=[DatasetSpec('A baz', data_type_inc='Baz', quantity=q_datasets)], - links=[LinkSpec('A qux', target_type='Qux', quantity=q_links)],) + bar = GroupSpec("A test group", data_type_def="Bar") + baz = DatasetSpec("A test dataset", "int", data_type_def="Baz") + qux = GroupSpec("A group to link", data_type_def="Qux") + foo = GroupSpec( + "A group containing a quantity of tests and datasets", + data_type_def="Foo", + groups=[GroupSpec("A bar", data_type_inc="Bar", quantity=q_groups)], + 
datasets=[DatasetSpec("A baz", data_type_inc="Baz", quantity=q_datasets)], + links=[LinkSpec("A qux", target_type="Qux", quantity=q_links)], + ) return (bar, foo, baz, qux) def configure_specs(self, specs): spec_catalog = SpecCatalog() for spec in specs: - spec_catalog.register_spec(spec, 'test.yaml') + spec_catalog.register_spec(spec, "test.yaml") self.namespace = SpecNamespace( - 'a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], version='0.1.0', catalog=spec_catalog) + "a test namespace", + CORE_NAMESPACE, + [{"source": "test.yaml"}], + version="0.1.0", + catalog=spec_catalog, + ) self.vmap = ValidatorMap(self.namespace) def get_test_builder(self, n_groups, n_datasets, n_links): - child_groups = [GroupBuilder(f'bar_{n}', attributes={'data_type': 'Bar'}) for n in range(n_groups)] - child_datasets = [DatasetBuilder(f'baz_{n}', n, attributes={'data_type': 'Baz'}) for n in range(n_datasets)] - child_links = [LinkBuilder(GroupBuilder(f'qux_{n}', attributes={'data_type': 'Qux'}), f'qux_{n}_link') - for n in range(n_links)] - return GroupBuilder('my_foo', attributes={'data_type': 'Foo'}, - groups=child_groups, datasets=child_datasets, links=child_links) + child_groups = [GroupBuilder(f"bar_{n}", attributes={"data_type": "Bar"}) for n in range(n_groups)] + child_datasets = [DatasetBuilder(f"baz_{n}", n, attributes={"data_type": "Baz"}) for n in range(n_datasets)] + child_links = [ + LinkBuilder( + GroupBuilder(f"qux_{n}", attributes={"data_type": "Qux"}), + f"qux_{n}_link", + ) + for n in range(n_links) + ] + return GroupBuilder( + "my_foo", + attributes={"data_type": "Foo"}, + groups=child_groups, + datasets=child_datasets, + links=child_links, + ) def test_valid_zero_or_many(self): - """"Verify that groups/datasets/links with ZERO_OR_MANY and a valid quantity correctly pass validation""" + """Verify that groups/datasets/links with ZERO_OR_MANY and a valid quantity correctly pass validation""" specs = self.create_test_specs(q_groups=ZERO_OR_MANY, q_datasets=ZERO_OR_MANY, q_links=ZERO_OR_MANY) self.configure_specs(specs) for n in [0, 1, 2, 5]: @@ -274,7 +391,7 @@ def test_valid_zero_or_many(self): self.assertEqual(len(results), 0) def test_valid_one_or_many(self): - """"Verify that groups/datasets/links with ONE_OR_MANY and a valid quantity correctly pass validation""" + """Verify that groups/datasets/links with ONE_OR_MANY and a valid quantity correctly pass validation""" specs = self.create_test_specs(q_groups=ONE_OR_MANY, q_datasets=ONE_OR_MANY, q_links=ONE_OR_MANY) self.configure_specs(specs) for n in [1, 2, 5]: @@ -284,7 +401,7 @@ def test_valid_one_or_many(self): self.assertEqual(len(results), 0) def test_valid_zero_or_one(self): - """"Verify that groups/datasets/links with ZERO_OR_ONE and a valid quantity correctly pass validation""" + """Verify that groups/datasets/links with ZERO_OR_ONE and a valid quantity correctly pass validation""" specs = self.create_test_specs(q_groups=ZERO_OR_ONE, q_datasets=ZERO_OR_ONE, q_links=ZERO_OR_ONE) self.configure_specs(specs) for n in [0, 1]: @@ -294,13 +411,15 @@ def test_valid_zero_or_one(self): self.assertEqual(len(results), 0) def test_valid_fixed_quantity(self): - """"Verify that groups/datasets/links with a correct fixed quantity correctly pass validation""" + """Verify that groups/datasets/links with a correct fixed quantity correctly pass validation""" self.configure_specs(self.create_test_specs(q_groups=2, q_datasets=3, q_links=5)) builder = self.get_test_builder(n_groups=2, n_datasets=3, n_links=5) results =
self.vmap.validate(builder) self.assertEqual(len(results), 0) - def test_missing_one_or_many_should_not_return_incorrect_quantity_error(self): + def test_missing_one_or_many_should_not_return_incorrect_quantity_error( + self, + ): """Verify that missing ONE_OR_MANY groups/datasets/links should not return an IncorrectQuantityError NOTE: a MissingDataType error should be returned instead @@ -311,14 +430,18 @@ def test_missing_one_or_many_should_not_return_incorrect_quantity_error(self): results = self.vmap.validate(builder) self.assertFalse(any(isinstance(e, IncorrectQuantityError) for e in results)) - def test_missing_fixed_quantity_should_not_return_incorrect_quantity_error(self): + def test_missing_fixed_quantity_should_not_return_incorrect_quantity_error( + self, + ): """Verify that missing groups/datasets/links should not return an IncorrectQuantityError""" self.configure_specs(self.create_test_specs(q_groups=5, q_datasets=3, q_links=2)) builder = self.get_test_builder(0, 0, 0) results = self.vmap.validate(builder) self.assertFalse(any(isinstance(e, IncorrectQuantityError) for e in results)) - def test_incorrect_fixed_quantity_should_return_incorrect_quantity_error(self): + def test_incorrect_fixed_quantity_should_return_incorrect_quantity_error( + self, + ): """Verify that an incorrect quantity of groups/datasets/links should return an IncorrectQuantityError""" self.configure_specs(self.create_test_specs(q_groups=5, q_datasets=5, q_links=5)) for n in [1, 2, 10]: @@ -328,7 +451,9 @@ def test_incorrect_fixed_quantity_should_return_incorrect_quantity_error(self): self.assertEqual(len(results), 3) self.assertTrue(all(isinstance(e, IncorrectQuantityError) for e in results)) - def test_incorrect_zero_or_one_quantity_should_return_incorrect_quantity_error(self): + def test_incorrect_zero_or_one_quantity_should_return_incorrect_quantity_error( + self, + ): """Verify that an incorrect ZERO_OR_ONE quantity of groups/datasets/links should return an IncorrectQuantityError """ @@ -348,148 +473,192 @@ def test_incorrect_quantity_error_message(self): self.assertEqual(len(results), 1) self.assertIsInstance(results[0], IncorrectQuantityError) message = str(results[0]) - self.assertTrue('expected a quantity of 2' in message) - self.assertTrue('received 7' in message) + self.assertTrue("expected a quantity of 2" in message) + self.assertTrue("received 7" in message) class TestDtypeValidation(TestCase): - def set_up_spec(self, dtype): spec_catalog = SpecCatalog() - spec = GroupSpec('A test group specification with a data type', - data_type_def='Bar', - datasets=[DatasetSpec('an example dataset', dtype, name='data')], - attributes=[AttributeSpec('attr1', 'an example attribute', dtype)]) - spec_catalog.register_spec(spec, 'test.yaml') + spec = GroupSpec( + "A test group specification with a data type", + data_type_def="Bar", + datasets=[DatasetSpec("an example dataset", dtype, name="data")], + attributes=[AttributeSpec("attr1", "an example attribute", dtype)], + ) + spec_catalog.register_spec(spec, "test.yaml") self.namespace = SpecNamespace( - 'a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], version='0.1.0', catalog=spec_catalog) + "a test namespace", + CORE_NAMESPACE, + [{"source": "test.yaml"}], + version="0.1.0", + catalog=spec_catalog, + ) self.vmap = ValidatorMap(self.namespace) def test_ascii_for_utf8(self): """Test that validator allows ASCII data where UTF8 is specified.""" - self.set_up_spec('text') - value = b'an ascii string' - bar_builder = GroupBuilder('my_bar', - 
attributes={'data_type': 'Bar', 'attr1': value}, - datasets=[DatasetBuilder('data', value)]) + self.set_up_spec("text") + value = b"an ascii string" + bar_builder = GroupBuilder( + "my_bar", + attributes={"data_type": "Bar", "attr1": value}, + datasets=[DatasetBuilder("data", value)], + ) results = self.vmap.validate(bar_builder) self.assertEqual(len(results), 0) def test_utf8_for_ascii(self): """Test that validator does not allow UTF8 where ASCII is specified.""" - self.set_up_spec('bytes') - value = 'a utf8 string' - bar_builder = GroupBuilder('my_bar', - attributes={'data_type': 'Bar', 'attr1': value}, - datasets=[DatasetBuilder('data', value)]) + self.set_up_spec("bytes") + value = "a utf8 string" + bar_builder = GroupBuilder( + "my_bar", + attributes={"data_type": "Bar", "attr1": value}, + datasets=[DatasetBuilder("data", value)], + ) results = self.vmap.validate(bar_builder) result_strings = set([str(s) for s in results]) - expected_errors = {"Bar/attr1 (my_bar.attr1): incorrect type - expected 'bytes', got 'utf'", - "Bar/data (my_bar/data): incorrect type - expected 'bytes', got 'utf'"} + expected_errors = { + "Bar/attr1 (my_bar.attr1): incorrect type - expected 'bytes', got 'utf'", + "Bar/data (my_bar/data): incorrect type - expected 'bytes', got 'utf'", + } self.assertEqual(result_strings, expected_errors) def test_int64_for_int8(self): """Test that validator allows int64 data where int8 is specified.""" - self.set_up_spec('int8') + self.set_up_spec("int8") value = np.int64(1) - bar_builder = GroupBuilder('my_bar', - attributes={'data_type': 'Bar', 'attr1': value}, - datasets=[DatasetBuilder('data', value)]) + bar_builder = GroupBuilder( + "my_bar", + attributes={"data_type": "Bar", "attr1": value}, + datasets=[DatasetBuilder("data", value)], + ) results = self.vmap.validate(bar_builder) self.assertEqual(len(results), 0) def test_int8_for_int64(self): """Test that validator does not allow int8 data where int64 is specified.""" - self.set_up_spec('int64') + self.set_up_spec("int64") value = np.int8(1) - bar_builder = GroupBuilder('my_bar', - attributes={'data_type': 'Bar', 'attr1': value}, - datasets=[DatasetBuilder('data', value)]) + bar_builder = GroupBuilder( + "my_bar", + attributes={"data_type": "Bar", "attr1": value}, + datasets=[DatasetBuilder("data", value)], + ) results = self.vmap.validate(bar_builder) result_strings = set([str(s) for s in results]) - expected_errors = {"Bar/attr1 (my_bar.attr1): incorrect type - expected 'int64', got 'int8'", - "Bar/data (my_bar/data): incorrect type - expected 'int64', got 'int8'"} + expected_errors = { + "Bar/attr1 (my_bar.attr1): incorrect type - expected 'int64', got 'int8'", + "Bar/data (my_bar/data): incorrect type - expected 'int64', got 'int8'", + } self.assertEqual(result_strings, expected_errors) def test_int64_for_numeric(self): """Test that validator allows int64 data where numeric is specified.""" - self.set_up_spec('numeric') + self.set_up_spec("numeric") value = np.int64(1) - bar_builder = GroupBuilder('my_bar', - attributes={'data_type': 'Bar', 'attr1': value}, - datasets=[DatasetBuilder('data', value)]) + bar_builder = GroupBuilder( + "my_bar", + attributes={"data_type": "Bar", "attr1": value}, + datasets=[DatasetBuilder("data", value)], + ) results = self.vmap.validate(bar_builder) self.assertEqual(len(results), 0) def test_bool_for_numeric(self): """Test that validator does not allow bool data where numeric is specified.""" - self.set_up_spec('numeric') + self.set_up_spec("numeric") value = True - bar_builder = 
GroupBuilder('my_bar', - attributes={'data_type': 'Bar', 'attr1': value}, - datasets=[DatasetBuilder('data', value)]) + bar_builder = GroupBuilder( + "my_bar", + attributes={"data_type": "Bar", "attr1": value}, + datasets=[DatasetBuilder("data", value)], + ) results = self.vmap.validate(bar_builder) result_strings = set([str(s) for s in results]) - expected_errors = {"Bar/attr1 (my_bar.attr1): incorrect type - expected 'numeric', got 'bool'", - "Bar/data (my_bar/data): incorrect type - expected 'numeric', got 'bool'"} + expected_errors = { + "Bar/attr1 (my_bar.attr1): incorrect type - expected 'numeric', got 'bool'", + "Bar/data (my_bar/data): incorrect type - expected 'numeric', got 'bool'", + } self.assertEqual(result_strings, expected_errors) def test_np_bool_for_bool(self): """Test that validator allows np.bool_ data where bool is specified.""" - self.set_up_spec('bool') + self.set_up_spec("bool") value = np.bool_(True) - bar_builder = GroupBuilder('my_bar', - attributes={'data_type': 'Bar', 'attr1': value}, - datasets=[DatasetBuilder('data', value)]) + bar_builder = GroupBuilder( + "my_bar", + attributes={"data_type": "Bar", "attr1": value}, + datasets=[DatasetBuilder("data", value)], + ) results = self.vmap.validate(bar_builder) self.assertEqual(len(results), 0) class Test1DArrayValidation(TestCase): - def set_up_spec(self, dtype): spec_catalog = SpecCatalog() - spec = GroupSpec('A test group specification with a data type', - data_type_def='Bar', - datasets=[DatasetSpec('an example dataset', dtype, name='data', shape=(None, ))], - attributes=[AttributeSpec('attr1', 'an example attribute', dtype, shape=(None, ))]) - spec_catalog.register_spec(spec, 'test.yaml') + spec = GroupSpec( + "A test group specification with a data type", + data_type_def="Bar", + datasets=[DatasetSpec("an example dataset", dtype, name="data", shape=(None,))], + attributes=[AttributeSpec("attr1", "an example attribute", dtype, shape=(None,))], + ) + spec_catalog.register_spec(spec, "test.yaml") self.namespace = SpecNamespace( - 'a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], version='0.1.0', catalog=spec_catalog) + "a test namespace", + CORE_NAMESPACE, + [{"source": "test.yaml"}], + version="0.1.0", + catalog=spec_catalog, + ) self.vmap = ValidatorMap(self.namespace) def test_scalar(self): """Test that validator does not allow a scalar where an array is specified.""" - self.set_up_spec('text') - value = 'a string' - bar_builder = GroupBuilder('my_bar', - attributes={'data_type': 'Bar', 'attr1': value}, - datasets=[DatasetBuilder('data', value)]) + self.set_up_spec("text") + value = "a string" + bar_builder = GroupBuilder( + "my_bar", + attributes={"data_type": "Bar", "attr1": value}, + datasets=[DatasetBuilder("data", value)], + ) results = self.vmap.validate(bar_builder) result_strings = set([str(s) for s in results]) - expected_errors = {("Bar/attr1 (my_bar.attr1): incorrect shape - expected an array of shape '(None,)', " - "got non-array data 'a string'"), - ("Bar/data (my_bar/data): incorrect shape - expected an array of shape '(None,)', " - "got non-array data 'a string'")} + expected_errors = { + ( + "Bar/attr1 (my_bar.attr1): incorrect shape - expected an array of shape" + " '(None,)', got non-array data 'a string'" + ), + ( + "Bar/data (my_bar/data): incorrect shape - expected an array of shape" + " '(None,)', got non-array data 'a string'" + ), + } self.assertEqual(result_strings, expected_errors) def test_empty_list(self): """Test that validator allows an empty list where an array is 
specified.""" - self.set_up_spec('text') + self.set_up_spec("text") value = [] - bar_builder = GroupBuilder('my_bar', - attributes={'data_type': 'Bar', 'attr1': value}, - datasets=[DatasetBuilder('data', value)]) + bar_builder = GroupBuilder( + "my_bar", + attributes={"data_type": "Bar", "attr1": value}, + datasets=[DatasetBuilder("data", value)], + ) results = self.vmap.validate(bar_builder) self.assertEqual(len(results), 0) def test_empty_nparray(self): """Test that validator allows an empty numpy array where an array is specified.""" - self.set_up_spec('text') + self.set_up_spec("text") value = np.array([]) # note: dtype is float64 - bar_builder = GroupBuilder('my_bar', - attributes={'data_type': 'Bar', 'attr1': value}, - datasets=[DatasetBuilder('data', value)]) + bar_builder = GroupBuilder( + "my_bar", + attributes={"data_type": "Bar", "attr1": value}, + datasets=[DatasetBuilder("data", value)], + ) results = self.vmap.validate(bar_builder) self.assertEqual(len(results), 0) @@ -497,44 +666,86 @@ def test_empty_nparray(self): class TestLinkable(TestCase): - def set_up_spec(self): spec_catalog = SpecCatalog() - typed_dataset_spec = DatasetSpec('A typed dataset', data_type_def='Foo') - typed_group_spec = GroupSpec('A typed group', data_type_def='Bar') - spec = GroupSpec('A test group specification with a data type', - data_type_def='Baz', - datasets=[ - DatasetSpec('A linkable child dataset', name='untyped_linkable_ds', - linkable=True, quantity=ZERO_OR_ONE), - DatasetSpec('A non-linkable child dataset', name='untyped_nonlinkable_ds', - linkable=False, quantity=ZERO_OR_ONE), - DatasetSpec('A linkable child dataset', data_type_inc='Foo', - name='typed_linkable_ds', linkable=True, quantity=ZERO_OR_ONE), - DatasetSpec('A non-linkable child dataset', data_type_inc='Foo', - name='typed_nonlinkable_ds', linkable=False, quantity=ZERO_OR_ONE), - ], - groups=[ - GroupSpec('A linkable child group', name='untyped_linkable_group', - linkable=True, quantity=ZERO_OR_ONE), - GroupSpec('A non-linkable child group', name='untyped_nonlinkable_group', - linkable=False, quantity=ZERO_OR_ONE), - GroupSpec('A linkable child group', data_type_inc='Bar', - name='typed_linkable_group', linkable=True, quantity=ZERO_OR_ONE), - GroupSpec('A non-linkable child group', data_type_inc='Bar', - name='typed_nonlinkable_group', linkable=False, quantity=ZERO_OR_ONE), - ]) - spec_catalog.register_spec(spec, 'test.yaml') - spec_catalog.register_spec(typed_dataset_spec, 'test.yaml') - spec_catalog.register_spec(typed_group_spec, 'test.yaml') + typed_dataset_spec = DatasetSpec("A typed dataset", data_type_def="Foo") + typed_group_spec = GroupSpec("A typed group", data_type_def="Bar") + spec = GroupSpec( + "A test group specification with a data type", + data_type_def="Baz", + datasets=[ + DatasetSpec( + "A linkable child dataset", + name="untyped_linkable_ds", + linkable=True, + quantity=ZERO_OR_ONE, + ), + DatasetSpec( + "A non-linkable child dataset", + name="untyped_nonlinkable_ds", + linkable=False, + quantity=ZERO_OR_ONE, + ), + DatasetSpec( + "A linkable child dataset", + data_type_inc="Foo", + name="typed_linkable_ds", + linkable=True, + quantity=ZERO_OR_ONE, + ), + DatasetSpec( + "A non-linkable child dataset", + data_type_inc="Foo", + name="typed_nonlinkable_ds", + linkable=False, + quantity=ZERO_OR_ONE, + ), + ], + groups=[ + GroupSpec( + "A linkable child group", + name="untyped_linkable_group", + linkable=True, + quantity=ZERO_OR_ONE, + ), + GroupSpec( + "A non-linkable child group", + 
name="untyped_nonlinkable_group", + linkable=False, + quantity=ZERO_OR_ONE, + ), + GroupSpec( + "A linkable child group", + data_type_inc="Bar", + name="typed_linkable_group", + linkable=True, + quantity=ZERO_OR_ONE, + ), + GroupSpec( + "A non-linkable child group", + data_type_inc="Bar", + name="typed_nonlinkable_group", + linkable=False, + quantity=ZERO_OR_ONE, + ), + ], + ) + spec_catalog.register_spec(spec, "test.yaml") + spec_catalog.register_spec(typed_dataset_spec, "test.yaml") + spec_catalog.register_spec(typed_group_spec, "test.yaml") self.namespace = SpecNamespace( - 'a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], version='0.1.0', catalog=spec_catalog) + "a test namespace", + CORE_NAMESPACE, + [{"source": "test.yaml"}], + version="0.1.0", + catalog=spec_catalog, + ) self.vmap = ValidatorMap(self.namespace) def validate_linkability(self, link, expect_error): """Execute a linkability test and assert whether or not an IllegalLinkError is returned""" self.set_up_spec() - builder = GroupBuilder('my_baz', attributes={'data_type': 'Baz'}, links=[link]) + builder = GroupBuilder("my_baz", attributes={"data_type": "Baz"}, links=[link]) result = self.vmap.validate(builder) if expect_error: self.assertEqual(len(result), 1) @@ -544,46 +755,54 @@ def validate_linkability(self, link, expect_error): def test_untyped_linkable_dataset_accepts_link(self): """Test that the validator accepts a link when the spec has an untyped linkable dataset""" - link = LinkBuilder(name='untyped_linkable_ds', builder=DatasetBuilder('foo')) + link = LinkBuilder(name="untyped_linkable_ds", builder=DatasetBuilder("foo")) self.validate_linkability(link, expect_error=False) def test_untyped_nonlinkable_dataset_does_not_accept_link(self): """Test that the validator returns an IllegalLinkError when the spec has an untyped non-linkable dataset""" - link = LinkBuilder(name='untyped_nonlinkable_ds', builder=DatasetBuilder('foo')) + link = LinkBuilder(name="untyped_nonlinkable_ds", builder=DatasetBuilder("foo")) self.validate_linkability(link, expect_error=True) def test_typed_linkable_dataset_accepts_link(self): """Test that the validator accepts a link when the spec has a typed linkable dataset""" - link = LinkBuilder(name='typed_linkable_ds', - builder=DatasetBuilder('foo', attributes={'data_type': 'Foo'})) + link = LinkBuilder( + name="typed_linkable_ds", + builder=DatasetBuilder("foo", attributes={"data_type": "Foo"}), + ) self.validate_linkability(link, expect_error=False) def test_typed_nonlinkable_dataset_does_not_accept_link(self): """Test that the validator returns an IllegalLinkError when the spec has a typed non-linkable dataset""" - link = LinkBuilder(name='typed_nonlinkable_ds', - builder=DatasetBuilder('foo', attributes={'data_type': 'Foo'})) + link = LinkBuilder( + name="typed_nonlinkable_ds", + builder=DatasetBuilder("foo", attributes={"data_type": "Foo"}), + ) self.validate_linkability(link, expect_error=True) def test_untyped_linkable_group_accepts_link(self): """Test that the validator accepts a link when the spec has an untyped linkable group""" - link = LinkBuilder(name='untyped_linkable_group', builder=GroupBuilder('foo')) + link = LinkBuilder(name="untyped_linkable_group", builder=GroupBuilder("foo")) self.validate_linkability(link, expect_error=False) def test_untyped_nonlinkable_group_does_not_accept_link(self): """Test that the validator returns an IllegalLinkError when the spec has an untyped non-linkable group""" - link = LinkBuilder(name='untyped_nonlinkable_group', 
builder=GroupBuilder('foo')) + link = LinkBuilder(name="untyped_nonlinkable_group", builder=GroupBuilder("foo")) self.validate_linkability(link, expect_error=True) def test_typed_linkable_group_accepts_link(self): """Test that the validator accepts a link when the spec has a typed linkable group""" - link = LinkBuilder(name='typed_linkable_group', - builder=GroupBuilder('foo', attributes={'data_type': 'Bar'})) + link = LinkBuilder( + name="typed_linkable_group", + builder=GroupBuilder("foo", attributes={"data_type": "Bar"}), + ) self.validate_linkability(link, expect_error=False) def test_typed_nonlinkable_group_does_not_accept_link(self): """Test that the validator returns an IllegalLinkError when the spec has a typed non-linkable group""" - link = LinkBuilder(name='typed_nonlinkable_group', - builder=GroupBuilder('foo', attributes={'data_type': 'Bar'})) + link = LinkBuilder( + name="typed_nonlinkable_group", + builder=GroupBuilder("foo", attributes={"data_type": "Bar"}), + ) self.validate_linkability(link, expect_error=True) @mock.patch("hdmf.validate.validator.DatasetValidator.validate") @@ -594,10 +813,16 @@ def test_should_not_validate_illegally_linked_objects(self, mock_validator): https://github.com/hdmf-dev/hdmf/issues/516 """ self.set_up_spec() - typed_link = LinkBuilder(name='typed_nonlinkable_ds', - builder=DatasetBuilder('foo', attributes={'data_type': 'Foo'})) - untyped_link = LinkBuilder(name='untyped_nonlinkable_ds', builder=DatasetBuilder('foo')) - builder = GroupBuilder('my_baz', attributes={'data_type': 'Baz'}, links=[typed_link, untyped_link]) + typed_link = LinkBuilder( + name="typed_nonlinkable_ds", + builder=DatasetBuilder("foo", attributes={"data_type": "Foo"}), + ) + untyped_link = LinkBuilder(name="untyped_nonlinkable_ds", builder=DatasetBuilder("foo")) + builder = GroupBuilder( + "my_baz", + attributes={"data_type": "Baz"}, + links=[typed_link, untyped_link], + ) _ = self.vmap.validate(builder) assert not mock_validator.called @@ -609,61 +834,72 @@ class TestMultipleNamedChildrenOfSameType(TestCase): def set_up_spec(self): spec_catalog = SpecCatalog() - dataset_spec = DatasetSpec('A dataset', data_type_def='Foo') - group_spec = GroupSpec('A group', data_type_def='Bar') - spec = GroupSpec('A test group specification with a data type', - data_type_def='Baz', - datasets=[ - DatasetSpec('Child Dataset A', name='a', data_type_inc='Foo'), - DatasetSpec('Child Dataset B', name='b', data_type_inc='Foo'), - ], - groups=[ - GroupSpec('Child Group X', name='x', data_type_inc='Bar'), - GroupSpec('Child Group Y', name='y', data_type_inc='Bar'), - ]) - spec_catalog.register_spec(spec, 'test.yaml') - spec_catalog.register_spec(dataset_spec, 'test.yaml') - spec_catalog.register_spec(group_spec, 'test.yaml') + dataset_spec = DatasetSpec("A dataset", data_type_def="Foo") + group_spec = GroupSpec("A group", data_type_def="Bar") + spec = GroupSpec( + "A test group specification with a data type", + data_type_def="Baz", + datasets=[ + DatasetSpec("Child Dataset A", name="a", data_type_inc="Foo"), + DatasetSpec("Child Dataset B", name="b", data_type_inc="Foo"), + ], + groups=[ + GroupSpec("Child Group X", name="x", data_type_inc="Bar"), + GroupSpec("Child Group Y", name="y", data_type_inc="Bar"), + ], + ) + spec_catalog.register_spec(spec, "test.yaml") + spec_catalog.register_spec(dataset_spec, "test.yaml") + spec_catalog.register_spec(group_spec, "test.yaml") self.namespace = SpecNamespace( - 'a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], version='0.1.0', 
catalog=spec_catalog) + "a test namespace", + CORE_NAMESPACE, + [{"source": "test.yaml"}], + version="0.1.0", + catalog=spec_catalog, + ) self.vmap = ValidatorMap(self.namespace) def validate_multiple_children(self, dataset_names, group_names): """Utility function to validate a builder with the specified named dataset and group children""" self.set_up_spec() - datasets = [DatasetBuilder(ds, attributes={'data_type': 'Foo'}) for ds in dataset_names] - groups = [GroupBuilder(gr, attributes={'data_type': 'Bar'}) for gr in group_names] - builder = GroupBuilder('my_baz', attributes={'data_type': 'Baz'}, - datasets=datasets, groups=groups) + datasets = [DatasetBuilder(ds, attributes={"data_type": "Foo"}) for ds in dataset_names] + groups = [GroupBuilder(gr, attributes={"data_type": "Bar"}) for gr in group_names] + builder = GroupBuilder( + "my_baz", + attributes={"data_type": "Baz"}, + datasets=datasets, + groups=groups, + ) return self.vmap.validate(builder) def test_missing_first_dataset_should_return_error(self): """Test that the validator returns a MissingDataType error if the first dataset is missing""" - result = self.validate_multiple_children(['b'], ['x', 'y']) + result = self.validate_multiple_children(["b"], ["x", "y"]) self.assertEqual(len(result), 1) self.assertIsInstance(result[0], MissingDataType) def test_missing_last_dataset_should_return_error(self): """Test that the validator returns a MissingDataType error if the last dataset is missing""" - result = self.validate_multiple_children(['a'], ['x', 'y']) + result = self.validate_multiple_children(["a"], ["x", "y"]) self.assertEqual(len(result), 1) self.assertIsInstance(result[0], MissingDataType) def test_missing_first_group_should_return_error(self): """Test that the validator returns a MissingDataType error if the first group is missing""" - result = self.validate_multiple_children(['a', 'b'], ['y']) + result = self.validate_multiple_children(["a", "b"], ["y"]) self.assertEqual(len(result), 1) self.assertIsInstance(result[0], MissingDataType) def test_missing_last_group_should_return_error(self): """Test that the validator returns a MissingDataType error if the last group is missing""" - result = self.validate_multiple_children(['a', 'b'], ['x']) + result = self.validate_multiple_children(["a", "b"], ["x"]) self.assertEqual(len(result), 1) self.assertIsInstance(result[0], MissingDataType) def test_no_errors_when_all_children_satisfied(self): """Test that the validator does not return an error if all child specs are satisfied""" - result = self.validate_multiple_children(['a', 'b'], ['x', 'y']) + result = self.validate_multiple_children(["a", "b"], ["x", "y"]) self.assertEqual(len(result), 0) @@ -674,25 +910,32 @@ class TestLinkAndChildMatchingDataType(TestCase): def set_up_spec(self): spec_catalog = SpecCatalog() - dataset_spec = DatasetSpec('A dataset', data_type_def='Foo') - group_spec = GroupSpec('A group', data_type_def='Bar') - spec = GroupSpec('A test group specification with a data type', - data_type_def='Baz', - datasets=[ - DatasetSpec('Child Dataset', name='dataset', data_type_inc='Foo'), - ], - groups=[ - GroupSpec('Child Group', name='group', data_type_inc='Bar'), - ], - links=[ - LinkSpec('Linked Dataset', name='dataset_link', target_type='Foo'), - LinkSpec('Linked Dataset', name='group_link', target_type='Bar') - ]) - spec_catalog.register_spec(spec, 'test.yaml') - spec_catalog.register_spec(dataset_spec, 'test.yaml') - spec_catalog.register_spec(group_spec, 'test.yaml') + dataset_spec = DatasetSpec("A dataset", 
data_type_def="Foo") + group_spec = GroupSpec("A group", data_type_def="Bar") + spec = GroupSpec( + "A test group specification with a data type", + data_type_def="Baz", + datasets=[ + DatasetSpec("Child Dataset", name="dataset", data_type_inc="Foo"), + ], + groups=[ + GroupSpec("Child Group", name="group", data_type_inc="Bar"), + ], + links=[ + LinkSpec("Linked Dataset", name="dataset_link", target_type="Foo"), + LinkSpec("Linked Dataset", name="group_link", target_type="Bar"), + ], + ) + spec_catalog.register_spec(spec, "test.yaml") + spec_catalog.register_spec(dataset_spec, "test.yaml") + spec_catalog.register_spec(group_spec, "test.yaml") self.namespace = SpecNamespace( - 'a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], version='0.1.0', catalog=spec_catalog) + "a test namespace", + CORE_NAMESPACE, + [{"source": "test.yaml"}], + version="0.1.0", + catalog=spec_catalog, + ) self.vmap = ValidatorMap(self.namespace) def validate_matching_link_data_type_case(self, datasets, groups, links): @@ -700,8 +943,13 @@ def validate_matching_link_data_type_case(self, datasets, groups, links): children and verify that a MissingDataType error is returned """ self.set_up_spec() - builder = GroupBuilder('my_baz', attributes={'data_type': 'Baz'}, - datasets=datasets, groups=groups, links=links) + builder = GroupBuilder( + "my_baz", + attributes={"data_type": "Baz"}, + datasets=datasets, + groups=groups, + links=links, + ) result = self.vmap.validate(builder) self.assertEqual(len(result), 1) self.assertIsInstance(result[0], MissingDataType) @@ -709,40 +957,58 @@ def validate_matching_link_data_type_case(self, datasets, groups, links): def test_error_on_missing_child_dataset(self): """Test that a MissingDataType is returned when the child dataset is missing""" datasets = [] - groups = [GroupBuilder('group', attributes={'data_type': 'Bar'})] + groups = [GroupBuilder("group", attributes={"data_type": "Bar"})] links = [ - LinkBuilder(name='dataset_link', builder=DatasetBuilder('foo', attributes={'data_type': 'Foo'})), - LinkBuilder(name='group_link', builder=GroupBuilder('bar', attributes={'data_type': 'Bar'})) + LinkBuilder( + name="dataset_link", + builder=DatasetBuilder("foo", attributes={"data_type": "Foo"}), + ), + LinkBuilder( + name="group_link", + builder=GroupBuilder("bar", attributes={"data_type": "Bar"}), + ), ] self.validate_matching_link_data_type_case(datasets, groups, links) def test_error_on_missing_linked_dataset(self): """Test that a MissingDataType is returned when the linked dataset is missing""" - datasets = [DatasetBuilder('dataset', attributes={'data_type': 'Foo'})] - groups = [GroupBuilder('group', attributes={'data_type': 'Bar'})] + datasets = [DatasetBuilder("dataset", attributes={"data_type": "Foo"})] + groups = [GroupBuilder("group", attributes={"data_type": "Bar"})] links = [ - LinkBuilder(name='group_link', builder=GroupBuilder('bar', attributes={'data_type': 'Bar'})) + LinkBuilder( + name="group_link", + builder=GroupBuilder("bar", attributes={"data_type": "Bar"}), + ) ] self.validate_matching_link_data_type_case(datasets, groups, links) def test_error_on_missing_group(self): """Test that a MissingDataType is returned when the child group is missing""" self.set_up_spec() - datasets = [DatasetBuilder('dataset', attributes={'data_type': 'Foo'})] + datasets = [DatasetBuilder("dataset", attributes={"data_type": "Foo"})] groups = [] links = [ - LinkBuilder(name='dataset_link', builder=DatasetBuilder('foo', attributes={'data_type': 'Foo'})), - 
LinkBuilder(name='group_link', builder=GroupBuilder('bar', attributes={'data_type': 'Bar'})) + LinkBuilder( + name="dataset_link", + builder=DatasetBuilder("foo", attributes={"data_type": "Foo"}), + ), + LinkBuilder( + name="group_link", + builder=GroupBuilder("bar", attributes={"data_type": "Bar"}), + ), ] self.validate_matching_link_data_type_case(datasets, groups, links) def test_error_on_missing_linked_group(self): """Test that a MissingDataType is returned when the linked group is missing""" self.set_up_spec() - datasets = [DatasetBuilder('dataset', attributes={'data_type': 'Foo'})] - groups = [GroupBuilder('group', attributes={'data_type': 'Bar'})] + datasets = [DatasetBuilder("dataset", attributes={"data_type": "Foo"})] + groups = [GroupBuilder("group", attributes={"data_type": "Bar"})] links = [ - LinkBuilder(name='dataset_link', builder=DatasetBuilder('foo', attributes={'data_type': 'Foo'})) + LinkBuilder( + name="dataset_link", + builder=DatasetBuilder("foo", attributes={"data_type": "Foo"}), + ) ] self.validate_matching_link_data_type_case(datasets, groups, links) @@ -754,20 +1020,26 @@ class TestMultipleChildrenAtDifferentLevelsOfInheritance(TestCase): def set_up_spec(self): spec_catalog = SpecCatalog() - dataset_spec = DatasetSpec('A dataset', data_type_def='Foo') - sub_dataset_spec = DatasetSpec('An Inheriting Dataset', - data_type_def='Bar', data_type_inc='Foo') - spec = GroupSpec('A test group specification with a data type', - data_type_def='Baz', - datasets=[ - DatasetSpec('Child Dataset', data_type_inc='Foo'), - DatasetSpec('Child Dataset', data_type_inc='Bar'), - ]) - spec_catalog.register_spec(spec, 'test.yaml') - spec_catalog.register_spec(dataset_spec, 'test.yaml') - spec_catalog.register_spec(sub_dataset_spec, 'test.yaml') + dataset_spec = DatasetSpec("A dataset", data_type_def="Foo") + sub_dataset_spec = DatasetSpec("An Inheriting Dataset", data_type_def="Bar", data_type_inc="Foo") + spec = GroupSpec( + "A test group specification with a data type", + data_type_def="Baz", + datasets=[ + DatasetSpec("Child Dataset", data_type_inc="Foo"), + DatasetSpec("Child Dataset", data_type_inc="Bar"), + ], + ) + spec_catalog.register_spec(spec, "test.yaml") + spec_catalog.register_spec(dataset_spec, "test.yaml") + spec_catalog.register_spec(sub_dataset_spec, "test.yaml") self.namespace = SpecNamespace( - 'a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], version='0.1.0', catalog=spec_catalog) + "a test namespace", + CORE_NAMESPACE, + [{"source": "test.yaml"}], + version="0.1.0", + catalog=spec_catalog, + ) self.vmap = ValidatorMap(self.namespace) def test_error_returned_when_child_at_highest_level_missing(self): @@ -775,10 +1047,8 @@ def test_error_returned_when_child_at_highest_level_missing(self): the highest level of the inheritance hierarchy is missing """ self.set_up_spec() - datasets = [ - DatasetBuilder('bar', attributes={'data_type': 'Bar'}) - ] - builder = GroupBuilder('my_baz', attributes={'data_type': 'Baz'}, datasets=datasets) + datasets = [DatasetBuilder("bar", attributes={"data_type": "Bar"})] + builder = GroupBuilder("my_baz", attributes={"data_type": "Baz"}, datasets=datasets) result = self.vmap.validate(builder) self.assertEqual(len(result), 1) self.assertIsInstance(result[0], MissingDataType) @@ -788,10 +1058,8 @@ def test_error_returned_when_child_at_lowest_level_missing(self): the lowest level of the inheritance hierarchy is missing """ self.set_up_spec() - datasets = [ - DatasetBuilder('foo', attributes={'data_type': 'Foo'}) - ] - builder = 
GroupBuilder('my_baz', attributes={'data_type': 'Baz'}, datasets=datasets) + datasets = [DatasetBuilder("foo", attributes={"data_type": "Foo"})] + builder = GroupBuilder("my_baz", attributes={"data_type": "Baz"}, datasets=datasets) result = self.vmap.validate(builder) self.assertEqual(len(result), 1) self.assertIsInstance(result[0], MissingDataType) @@ -802,10 +1070,10 @@ def test_both_levels_of_hierarchy_validated(self): """ self.set_up_spec() datasets = [ - DatasetBuilder('foo', attributes={'data_type': 'Foo'}), - DatasetBuilder('bar', attributes={'data_type': 'Bar'}) + DatasetBuilder("foo", attributes={"data_type": "Foo"}), + DatasetBuilder("bar", attributes={"data_type": "Bar"}), ] - builder = GroupBuilder('my_baz', attributes={'data_type': 'Baz'}, datasets=datasets) + builder = GroupBuilder("my_baz", attributes={"data_type": "Baz"}, datasets=datasets) result = self.vmap.validate(builder) self.assertEqual(len(result), 0) @@ -817,10 +1085,10 @@ def test_both_levels_of_hierarchy_validated_inverted_order(self): """ self.set_up_spec() datasets = [ - DatasetBuilder('bar', attributes={'data_type': 'Bar'}), - DatasetBuilder('foo', attributes={'data_type': 'Foo'}) + DatasetBuilder("bar", attributes={"data_type": "Bar"}), + DatasetBuilder("foo", attributes={"data_type": "Foo"}), ] - builder = GroupBuilder('my_baz', attributes={'data_type': 'Baz'}, datasets=datasets) + builder = GroupBuilder("my_baz", attributes={"data_type": "Baz"}, datasets=datasets) result = self.vmap.validate(builder) self.assertEqual(len(result), 0) @@ -848,32 +1116,51 @@ class TestExtendedIncDataTypes(TestCase): def setup_spec(self): """Prepare a set of specs for tests which includes an anonymous data type extension""" spec_catalog = SpecCatalog() - attr_foo = AttributeSpec(name='foo', doc='an attribute', dtype='text') - attr_bar = AttributeSpec(name='bar', doc='an attribute', dtype='numeric') - d1_spec = DatasetSpec(doc='type D1', data_type_def='D1', dtype='numeric', - attributes=[attr_foo]) - d2_spec = DatasetSpec(doc='type D2', data_type_def='D2', data_type_inc=d1_spec) - g1_spec = GroupSpec(doc='type G1', data_type_def='G1', - datasets=[DatasetSpec(doc='D1 extension', data_type_inc=d1_spec, - attributes=[attr_foo, attr_bar])]) + attr_foo = AttributeSpec(name="foo", doc="an attribute", dtype="text") + attr_bar = AttributeSpec(name="bar", doc="an attribute", dtype="numeric") + d1_spec = DatasetSpec( + doc="type D1", + data_type_def="D1", + dtype="numeric", + attributes=[attr_foo], + ) + d2_spec = DatasetSpec(doc="type D2", data_type_def="D2", data_type_inc=d1_spec) + g1_spec = GroupSpec( + doc="type G1", + data_type_def="G1", + datasets=[ + DatasetSpec( + doc="D1 extension", + data_type_inc=d1_spec, + attributes=[attr_foo, attr_bar], + ) + ], + ) for spec in [d1_spec, d2_spec, g1_spec]: - spec_catalog.register_spec(spec, 'test.yaml') - self.namespace = SpecNamespace('a test namespace', CORE_NAMESPACE, - [{'source': 'test.yaml'}], version='0.1.0', catalog=spec_catalog) + spec_catalog.register_spec(spec, "test.yaml") + self.namespace = SpecNamespace( + "a test namespace", + CORE_NAMESPACE, + [{"source": "test.yaml"}], + version="0.1.0", + catalog=spec_catalog, + ) self.vmap = ValidatorMap(self.namespace) - def test_missing_additional_attribute_on_anonymous_data_type_extension(self): + def test_missing_additional_attribute_on_anonymous_data_type_extension( + self, + ): """Verify that a MissingError is returned when a required attribute from an anonymous extension is not present """ self.setup_spec() - dataset = 
DatasetBuilder('test_d1', 42.0, attributes={'data_type': 'D1', 'foo': 'xyz'}) - builder = GroupBuilder('test_g1', attributes={'data_type': 'G1'}, datasets=[dataset]) + dataset = DatasetBuilder("test_d1", 42.0, attributes={"data_type": "D1", "foo": "xyz"}) + builder = GroupBuilder("test_g1", attributes={"data_type": "G1"}, datasets=[dataset]) result = self.vmap.validate(builder) self.assertEqual(len(result), 1) error = result[0] self.assertIsInstance(error, MissingError) - self.assertTrue('G1/D1/bar' in str(error)) + self.assertTrue("G1/D1/bar" in str(error)) def test_validate_child_type_against_anonymous_data_type_extension(self): """Verify that a MissingError is returned when a required attribute from an @@ -881,21 +1168,21 @@ def test_validate_child_type_against_anonymous_data_type_extension(self): type included in the anonymous extension. """ self.setup_spec() - dataset = DatasetBuilder('test_d2', 42.0, attributes={'data_type': 'D2', 'foo': 'xyz'}) - builder = GroupBuilder('test_g1', attributes={'data_type': 'G1'}, datasets=[dataset]) + dataset = DatasetBuilder("test_d2", 42.0, attributes={"data_type": "D2", "foo": "xyz"}) + builder = GroupBuilder("test_g1", attributes={"data_type": "G1"}, datasets=[dataset]) result = self.vmap.validate(builder) self.assertEqual(len(result), 1) error = result[0] self.assertIsInstance(error, MissingError) - self.assertTrue('G1/D1/bar' in str(error)) + self.assertTrue("G1/D1/bar" in str(error)) def test_redundant_attribute_in_spec(self): """Test that only one MissingError is returned when an attribute is missing which is redundantly defined in both a base data type and an inner data type """ self.setup_spec() - dataset = DatasetBuilder('test_d2', 42.0, attributes={'data_type': 'D2', 'bar': 5}) - builder = GroupBuilder('test_g1', attributes={'data_type': 'G1'}, datasets=[dataset]) + dataset = DatasetBuilder("test_d2", 42.0, attributes={"data_type": "D2", "bar": 5}) + builder = GroupBuilder("test_g1", attributes={"data_type": "G1"}, datasets=[dataset]) result = self.vmap.validate(builder) self.assertEqual(len(result), 1) @@ -911,7 +1198,7 @@ class TestReferenceDatasetsRoundTrip(ValidatorTestBase): """ def setUp(self): - self.filename = 'test_ref_dataset.h5' + self.filename = "test_ref_dataset.h5" super().setUp() def tearDown(self): @@ -920,34 +1207,46 @@ def tearDown(self): def getSpecs(self): qux_spec = DatasetSpec( - doc='a simple scalar dataset', - data_type_def='Qux', - dtype='int', - shape=None + doc="a simple scalar dataset", + data_type_def="Qux", + dtype="int", + shape=None, ) baz_spec = DatasetSpec( - doc='a dataset with a compound datatype that includes a reference', - data_type_def='Baz', + doc="a dataset with a compound datatype that includes a reference", + data_type_def="Baz", dtype=[ - DtypeSpec('x', doc='x-value', dtype='int'), - DtypeSpec('y', doc='y-ref', dtype=RefSpec('Qux', reftype='object')) + DtypeSpec("x", doc="x-value", dtype="int"), + DtypeSpec("y", doc="y-ref", dtype=RefSpec("Qux", reftype="object")), ], - shape=None + shape=None, ) bar_spec = DatasetSpec( - doc='a dataset of an array of references', - dtype=RefSpec('Qux', reftype='object'), - data_type_def='Bar', - shape=(None,) + doc="a dataset of an array of references", + dtype=RefSpec("Qux", reftype="object"), + data_type_def="Bar", + shape=(None,), ) foo_spec = GroupSpec( - doc='a base group for containing test datasets', - data_type_def='Foo', + doc="a base group for containing test datasets", + data_type_def="Foo", datasets=[ - DatasetSpec(doc='optional Bar', 
data_type_inc=bar_spec, quantity=ZERO_OR_ONE), - DatasetSpec(doc='optional Baz', data_type_inc=baz_spec, quantity=ZERO_OR_ONE), - DatasetSpec(doc='multiple qux', data_type_inc=qux_spec, quantity=ONE_OR_MANY) - ] + DatasetSpec( + doc="optional Bar", + data_type_inc=bar_spec, + quantity=ZERO_OR_ONE, + ), + DatasetSpec( + doc="optional Baz", + data_type_inc=baz_spec, + quantity=ZERO_OR_ONE, + ), + DatasetSpec( + doc="multiple qux", + data_type_inc=qux_spec, + quantity=ONE_OR_MANY, + ), + ], ) return (foo_spec, bar_spec, baz_spec, qux_spec) @@ -963,10 +1262,10 @@ def runBuilderRoundTrip(self, builder): typemap = TypeMap(ns_catalog) self.manager = BuildManager(typemap) - with HDF5IO(self.filename, manager=self.manager, mode='w') as write_io: + with HDF5IO(self.filename, manager=self.manager, mode="w") as write_io: write_io.write_builder(builder) - with HDF5IO(self.filename, manager=self.manager, mode='r') as read_io: + with HDF5IO(self.filename, manager=self.manager, mode="r") as read_io: read_builder = read_io.read_builder() errors = self.vmap.validate(read_builder) self.assertEqual(len(errors), 0, errors) @@ -974,38 +1273,38 @@ def runBuilderRoundTrip(self, builder): def test_round_trip_validation_of_reference_dataset_array(self): """Verify that a dataset builder containing an array of references passes validation after a round trip""" - qux1 = DatasetBuilder('q1', 5, attributes={'data_type': 'Qux'}) - qux2 = DatasetBuilder('q2', 10, attributes={'data_type': 'Qux'}) + qux1 = DatasetBuilder("q1", 5, attributes={"data_type": "Qux"}) + qux2 = DatasetBuilder("q2", 10, attributes={"data_type": "Qux"}) bar = DatasetBuilder( - name='bar', + name="bar", data=[ReferenceBuilder(qux1), ReferenceBuilder(qux2)], - attributes={'data_type': 'Bar'}, - dtype='object' + attributes={"data_type": "Bar"}, + dtype="object", ) foo = GroupBuilder( - name='foo', + name="foo", datasets=[bar, qux1, qux2], - attributes={'data_type': 'Foo'} + attributes={"data_type": "Foo"}, ) self.runBuilderRoundTrip(foo) def test_round_trip_validation_of_compound_dtype_with_reference(self): """Verify that a dataset builder containing data with a compound dtype containing a reference passes validation after a round trip""" - qux1 = DatasetBuilder('q1', 5, attributes={'data_type': 'Qux'}) - qux2 = DatasetBuilder('q2', 10, attributes={'data_type': 'Qux'}) + qux1 = DatasetBuilder("q1", 5, attributes={"data_type": "Qux"}) + qux2 = DatasetBuilder("q2", 10, attributes={"data_type": "Qux"}) baz = DatasetBuilder( - name='baz', + name="baz", data=[(10, ReferenceBuilder(qux1))], dtype=[ - DtypeSpec('x', doc='x-value', dtype='int'), - DtypeSpec('y', doc='y-ref', dtype=RefSpec('Qux', reftype='object')) + DtypeSpec("x", doc="x-value", dtype="int"), + DtypeSpec("y", doc="y-ref", dtype=RefSpec("Qux", reftype="object")), ], - attributes={'data_type': 'Baz'} + attributes={"data_type": "Baz"}, ) foo = GroupBuilder( - name='foo', + name="foo", datasets=[baz, qux1, qux2], - attributes={'data_type': 'Foo'} + attributes={"data_type": "Foo"}, ) self.runBuilderRoundTrip(foo)
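
For reference, the validator tests reformatted above all share the same scaffolding: register specs in a SpecCatalog, wrap the catalog in a SpecNamespace, build a ValidatorMap, and call validate() on a builder tree. The following is a minimal, self-contained sketch of that pattern, not code from this PR; the namespace name and the single spec are illustrative only.

from hdmf.build import GroupBuilder
from hdmf.spec import GroupSpec, SpecCatalog, SpecNamespace
from hdmf.validate import ValidatorMap

CORE_NAMESPACE = "test_core"  # illustrative namespace name; the test modules define their own constant

# Register a single typed group spec and expose it through a namespace.
spec_catalog = SpecCatalog()
baz_spec = GroupSpec("A test group specification with a data type", data_type_def="Baz")
spec_catalog.register_spec(baz_spec, "test.yaml")
namespace = SpecNamespace(
    "a test namespace",
    CORE_NAMESPACE,
    [{"source": "test.yaml"}],
    version="0.1.0",
    catalog=spec_catalog,
)

# Validate a builder tree against the namespace; an empty result means the builder is valid.
vmap = ValidatorMap(namespace)
builder = GroupBuilder("my_baz", attributes={"data_type": "Baz"})
errors = vmap.validate(builder)
assert len(errors) == 0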
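
The mechanical changes in these hunks (double quotes, trailing commas, call arguments exploded one per line) come from running black over the test suite. As a rough sketch, assuming black 23.x is installed and using the line-length and preview settings adopted elsewhere in this PR, the same transformation can be reproduced on a single snippet programmatically rather than via black on the command line.

# Rough sketch: reproduce the reformatting shown above on one snippet.
# Assumes black>=23 is installed; line_length=120 and preview=True mirror this PR's configuration.
import black

src = "link = LinkBuilder(name='typed_linkable_ds', builder=DatasetBuilder('foo', attributes={'data_type': 'Foo'}))\n"
formatted = black.format_str(src, mode=black.Mode(line_length=120, preview=True))
print(formatted)
# black normalizes the quotes to double quotes and, once a call no longer fits on one
# 120-character line, wraps it onto one argument per line with a trailing comma.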