Skip to content

Commit

Permalink
adding IO tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mooniean committed Feb 5, 2024
1 parent fe6d1e6 commit c26bbc7
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 11 deletions.
2 changes: 1 addition & 1 deletion src/caked/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(
self.training = training

@abstractmethod
def load(self, datapath, datatype):
    """Load files of the given *datatype* found under *datapath*.

    Abstract: concrete loaders (e.g. DiskDataLoader) implement the actual
    directory scan and dataset construction.
    """

@abstractmethod
Expand Down
17 changes: 10 additions & 7 deletions src/caked/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@ def __init__(
self.training = training
self.pipeline = pipeline
if classes is None:
classes = []
self.classes = []
else:
self.classes = classes

def load(self, datapath, datatype):
def load(self, datapath, datatype) -> None:
paths = [f for f in os.listdir(datapath) if "." + datatype in f]

random.shuffle(paths)
Expand All @@ -57,14 +59,16 @@ def load(self, datapath, datatype):
(np.asarray(ids)[~class_check]),
)

# subset affinity matrix with only the relevant classes

paths = [p for p in paths for c in self.classes if c in p.split("_")[0]]
paths = [
Path(datapath) / p
for p in paths
for c in self.classes
if c in p.split("_")[0]
]
if self.dataset_size is not None:
paths = paths[: self.dataset_size]

self.dataset = DiskDataset(paths=paths, datatype=datatype)
return super().load()

def process(self):
return super().process()
Expand Down Expand Up @@ -92,7 +96,6 @@ def __init__(
self.transform = input_transform
self.datatype = datatype
self.shiftmin = shiftmin
super().__init__()

def __len__(self):
    # Dataset length == number of file paths supplied at construction.
    return len(self.paths)
Expand Down
90 changes: 87 additions & 3 deletions tests/test_io.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,103 @@
from __future__ import annotations

import os
import random
from pathlib import Path

import numpy as np
import pytest
import torch
from tests import testdata_mrc

from caked.dataloader import DiskDataLoader, DiskDataset

# Working directory at import time (kept so tests that chdir could restore it).
ORIG_DIR = Path.cwd()
# Directory holding the bundled .mrc fixture files, located via the package.
TEST_DATA_MRC = Path(testdata_mrc.__file__).parent
DISK_PIPELINE = "disk"
# Dataset-size settings: None keeps every matching file, 3 truncates to three.
DATASET_SIZE_ALL = None
DATASET_SIZE_SOME = 3
# Class ids encoded in fixture filenames (first "_"-separated token of the name).
DISK_CLASSES_FULL = ["1b23", "1dfo", "1dkg", "1e3p"]
DISK_CLASSES_SOME = ["1b23", "1dkg"]
# "2b3a" is presumably absent from the fixtures — exercises the error path.
DISK_CLASSES_MISSING = ["2b3a", "1b23"]
DISK_CLASSES_NONE = None
DATATYPE_MRC = "mrc"


def test_class_instantiation():
    """DiskDataLoader stores the pipeline/classes/size it was constructed with."""
    # Diff residue resolved: keep only the new keyword values (the old literal
    # "test"/3 arguments were the pre-change lines of the same hunk).
    test_loader = DiskDataLoader(
        pipeline=DISK_PIPELINE,
        classes=DISK_CLASSES_SOME,
        dataset_size=DATASET_SIZE_SOME,
        save_to_disk=False,
        training=True,
    )
    assert isinstance(test_loader, DiskDataLoader)
    assert test_loader.pipeline == DISK_PIPELINE


def test_dataset_instantiation():
    """A DiskDataset can be built from a bare list of path strings."""
    dataset = DiskDataset(paths=["test"])
    assert isinstance(dataset, DiskDataset)


def test_load_dataset_no_classes():
    """With classes=None, load() discovers every class present on disk."""
    loader = DiskDataLoader(
        pipeline=DISK_PIPELINE,
        classes=DISK_CLASSES_NONE,
        dataset_size=DATASET_SIZE_ALL,
    )
    loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC)
    assert isinstance(loader.dataset, DiskDataset)
    # Element-wise comparison (classes may not be a plain list after load).
    assert list(loader.classes) == list(DISK_CLASSES_FULL)


def test_load_dataset_all_classes():
    """Requesting the full class list keeps exactly those classes after load()."""
    loader = DiskDataLoader(
        pipeline=DISK_PIPELINE,
        classes=DISK_CLASSES_FULL,
        dataset_size=DATASET_SIZE_ALL,
    )
    loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC)
    assert isinstance(loader.dataset, DiskDataset)
    # Element-wise comparison (classes may not be a plain list after load).
    assert list(loader.classes) == list(DISK_CLASSES_FULL)


def test_load_dataset_some_classes():
    """Requesting a subset of classes restricts the loader to that subset."""
    loader = DiskDataLoader(
        pipeline=DISK_PIPELINE,
        classes=DISK_CLASSES_SOME,
        dataset_size=DATASET_SIZE_ALL,
    )
    loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC)
    assert isinstance(loader.dataset, DiskDataset)
    # Element-wise comparison (classes may not be a plain list after load).
    assert list(loader.classes) == list(DISK_CLASSES_SOME)


def test_load_dataset_missing_class():
    """load() raises when a requested class has no files on disk.

    The original version re-implemented the loader's directory scan
    (os.listdir + shuffle + deprecated ``np.in1d``) solely to feed a debug
    ``print``; that duplicated-internals scaffolding is removed — the raised
    exception is the behavior under test.
    """
    test_loader = DiskDataLoader(
        pipeline=DISK_PIPELINE,
        classes=DISK_CLASSES_MISSING,
        dataset_size=DATASET_SIZE_ALL,
    )
    # Constructor stores the requested classes verbatim, even missing ones.
    assert test_loader.classes == DISK_CLASSES_MISSING
    # "2b3a" is expected to be absent from the fixtures, so load() must fail
    # with a message naming the missing classes.
    with pytest.raises(Exception, match=r".* Missing classes: .*"):
        test_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC)


def test_one_image():
    """Items from a loaded dataset are (tensor, class-id) pairs.

    Fixed: index the dataset with ``dataset[1]`` instead of calling the
    dunder ``__getitem__`` directly.
    """
    test_loader = DiskDataLoader(
        pipeline=DISK_PIPELINE,
        classes=DISK_CLASSES_NONE,
        dataset_size=DATASET_SIZE_ALL,
    )
    test_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC)
    test_item_image, test_item_name = test_loader.dataset[1]
    assert test_item_name in DISK_CLASSES_FULL
    assert isinstance(test_item_image, torch.Tensor)

0 comments on commit c26bbc7

Please sign in to comment.