From c26bbc736be94f5bf90903221382e0ad10c40ed6 Mon Sep 17 00:00:00 2001
From: mooniean
Date: Mon, 5 Feb 2024 17:46:42 +0000
Subject: [PATCH] Add IO tests

---
 src/caked/base.py       |  2 +-
 src/caked/dataloader.py | 17 ++++----
 tests/test_io.py        | 90 +++++++++++++++++++++++++++++++++++++++--
 3 files changed, 98 insertions(+), 11 deletions(-)

diff --git a/src/caked/base.py b/src/caked/base.py
index 5b21bb6..767bc8a 100644
--- a/src/caked/base.py
+++ b/src/caked/base.py
@@ -21,7 +21,7 @@ def __init__(
         self.training = training
 
     @abstractmethod
-    def load(self):
+    def load(self, datapath, datatype):
         pass
 
     @abstractmethod
diff --git a/src/caked/dataloader.py b/src/caked/dataloader.py
index 7eb01d9..590d063 100644
--- a/src/caked/dataloader.py
+++ b/src/caked/dataloader.py
@@ -29,9 +29,11 @@ def __init__(
         self.training = training
         self.pipeline = pipeline
         if classes is None:
-            classes = []
+            self.classes = []
+        else:
+            self.classes = classes
 
-    def load(self, datapath, datatype):
+    def load(self, datapath, datatype) -> None:
         paths = [f for f in os.listdir(datapath) if "." + datatype in f]
 
         random.shuffle(paths)
@@ -57,14 +59,16 @@ def load(self, datapath, datatype):
                 (np.asarray(ids)[~class_check]),
             )
 
-        # subset affinity matrix with only the relevant classes
-
-        paths = [p for p in paths for c in self.classes if c in p.split("_")[0]]
+        paths = [
+            Path(datapath) / p
+            for p in paths
+            for c in self.classes
+            if c in p.split("_")[0]
+        ]
 
         if self.dataset_size is not None:
             paths = paths[: self.dataset_size]
 
         self.dataset = DiskDataset(paths=paths, datatype=datatype)
-        return super().load()
 
     def process(self):
         return super().process()
@@ -92,7 +96,6 @@ def __init__(
         self.transform = input_transform
         self.datatype = datatype
         self.shiftmin = shiftmin
-        super().__init__()
 
     def __len__(self):
         return len(self.paths)
diff --git a/tests/test_io.py b/tests/test_io.py
index e584474..174943a 100644
--- a/tests/test_io.py
+++ b/tests/test_io.py
@@ -1,19 +1,103 @@
 from __future__ import annotations
 
+import os
+import random
+from pathlib import Path
+
+import numpy as np
+import pytest
+import torch
+from tests import testdata_mrc
+
 from caked.dataloader import DiskDataLoader, DiskDataset
 
+ORIG_DIR = Path.cwd()
+TEST_DATA_MRC = Path(testdata_mrc.__file__).parent
+DISK_PIPELINE = "disk"
+DATASET_SIZE_ALL = None
+DATASET_SIZE_SOME = 3
+DISK_CLASSES_FULL = ["1b23", "1dfo", "1dkg", "1e3p"]
+DISK_CLASSES_SOME = ["1b23", "1dkg"]
+DISK_CLASSES_MISSING = ["2b3a", "1b23"]
+DISK_CLASSES_NONE = None
+DATATYPE_MRC = "mrc"
+
 
 def test_class_instantiation():
     test_loader = DiskDataLoader(
-        pipeline="test",
-        classes=["test"],
-        dataset_size=3,
+        pipeline=DISK_PIPELINE,
+        classes=DISK_CLASSES_SOME,
+        dataset_size=DATASET_SIZE_SOME,
         save_to_disk=False,
         training=True,
     )
     assert isinstance(test_loader, DiskDataLoader)
+    assert test_loader.pipeline == DISK_PIPELINE
 
 
 def test_dataset_instantiation():
     test_dataset = DiskDataset(paths=["test"])
     assert isinstance(test_dataset, DiskDataset)
+
+
+def test_load_dataset_no_classes():
+    test_loader = DiskDataLoader(
+        pipeline=DISK_PIPELINE, classes=DISK_CLASSES_NONE, dataset_size=DATASET_SIZE_ALL
+    )
+    test_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC)
+    assert isinstance(test_loader.dataset, DiskDataset)
+    assert len(test_loader.classes) == len(DISK_CLASSES_FULL)
+    assert all(a == b for a, b in zip(test_loader.classes, DISK_CLASSES_FULL))
+
+
+def test_load_dataset_all_classes():
+    test_loader = DiskDataLoader(
+        pipeline=DISK_PIPELINE, classes=DISK_CLASSES_FULL, dataset_size=DATASET_SIZE_ALL
+    )
+    test_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC)
+    assert isinstance(test_loader.dataset, DiskDataset)
+    assert len(test_loader.classes) == len(DISK_CLASSES_FULL)
+    assert all(a == b for a, b in zip(test_loader.classes, DISK_CLASSES_FULL))
+
+
+def test_load_dataset_some_classes():
+    test_loader = DiskDataLoader(
+        pipeline=DISK_PIPELINE, classes=DISK_CLASSES_SOME, dataset_size=DATASET_SIZE_ALL
+    )
+    test_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC)
+    assert isinstance(test_loader.dataset, DiskDataset)
+    assert len(test_loader.classes) == len(DISK_CLASSES_SOME)
+    assert all(a == b for a, b in zip(test_loader.classes, DISK_CLASSES_SOME))
+
+
+def test_load_dataset_missing_class():
+    test_loader = DiskDataLoader(
+        pipeline=DISK_PIPELINE,
+        classes=DISK_CLASSES_MISSING,
+        dataset_size=DATASET_SIZE_ALL,
+    )
+    paths = [f for f in os.listdir(TEST_DATA_MRC) if "." + DATATYPE_MRC in f]
+
+    random.shuffle(paths)
+
+    # IDs currently assume the class ID is the first "_"-separated filename token
+    # TODO: generalise this and document the naming convention in the README
+    ids = np.unique([f.split("_")[0] for f in paths])
+    assert test_loader.classes == DISK_CLASSES_MISSING
+    classes = test_loader.classes
+    class_check = np.in1d(classes, ids)
+    # the requested classes include an ID that is absent from the test data
+    assert not np.all(class_check)
+    with pytest.raises(Exception, match=r".* Missing classes: .*"):
+        test_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC)
+
+
+def test_one_image():
+    test_loader = DiskDataLoader(
+        pipeline=DISK_PIPELINE, classes=DISK_CLASSES_NONE, dataset_size=DATASET_SIZE_ALL
+    )
+    test_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC)
+    test_dataset = test_loader.dataset
+    test_item_image, test_item_name = test_dataset[1]
+    assert test_item_name in DISK_CLASSES_FULL
+    assert isinstance(test_item_image, torch.Tensor)
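--

Usage note (not part of the patch, ignored by git am): a minimal sketch of how
the reworked loader is driven end to end, assuming a directory of MRC volumes
named "<class_id>_<rest>.mrc" as in the tests above. The "data/mrc" path is
hypothetical; DiskDataLoader, DiskDataset and load() are the APIs touched by
this patch.

    from caked.dataloader import DiskDataLoader

    # Restrict loading to two class IDs; classes=None keeps every class
    # discovered on disk, dataset_size=None keeps every matching file.
    loader = DiskDataLoader(
        pipeline="disk",
        classes=["1b23", "1dkg"],
        dataset_size=None,
    )

    # load() scans datapath for "*.mrc" files and raises if a requested
    # class ID never appears as the first "_"-separated filename token.
    loader.load(datapath="data/mrc", datatype="mrc")

    # Indexing the DiskDataset yields (image tensor, class ID), as
    # exercised by test_one_image above.
    image, class_id = loader.dataset[0]
    print(image.shape, class_id)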