Skip to content

Commit

Permalink
lint fixed, pre-commit ran
Browse files Browse the repository at this point in the history
  • Loading branch information
mooniean committed Jan 30, 2024
1 parent 87f109b commit 8439d04
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 22 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ dependencies = ["torch", "numpy"]
test = [
"pytest >=6",
"pytest-cov >=3",
"pre-commit",
]
dev = [
"pytest >=6",
Expand Down
1 change: 0 additions & 1 deletion src/caked/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,5 @@
"""
from __future__ import annotations


__all__ = ("__version__",)
__version__ = "0.1.0"
24 changes: 12 additions & 12 deletions src/caked/base.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
from abc import ABC, abstractmethod
from pathlib import Path
from __future__ import annotations

from torch.utils.data import DataLoader, Dataset
from abc import ABC, abstractmethod


class AbstractDataLoader(ABC):
def __init__(self, pipeline: str, classes: list[str], dataset_size: int, save_to_disk: bool, training: bool):
def __init__(
self,
pipeline: str,
classes: list[str],
dataset_size: int,
save_to_disk: bool,
training: bool,
):
self.pipeline = pipeline
self.classes = classes
self.dataset_size = dataset_size
Expand All @@ -26,16 +32,10 @@ def get_loader(self, split_size: float, batch_size: int):


class AbstractDataset(ABC):
def __init__(self, origin: str, classes: Path) -> None:
pass

def __len__(self) -> int:
pass

@abstractmethod
def set_len(self, length:int):
def set_len(self, length: int):
pass

@abstractmethod
def augment(self, augment:bool, aug_type:str):
def augment(self, augment: bool, aug_type: str):
pass
20 changes: 15 additions & 5 deletions src/caked/dataloader.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,24 @@
from __future__ import annotations

from .base import AbstractDataLoader


class DiskDataLoader(AbstractDataLoader):
def __init__(self, pipeline: str, classes: list[str], dataset_size: int, save_to_disk: bool, training: bool) -> None:
def __init__(
self,
pipeline: str,
classes: list[str],
dataset_size: int,
save_to_disk: bool,
training: bool,
) -> None:
super().__init__(pipeline, classes, dataset_size, save_to_disk, training)

def load(self):
return super().load()

def process(self):
return super().process()

def get_loader(self, split_size: float, batch_size: int):
return super().get_loader(split_size, batch_size)
return super().get_loader(split_size, batch_size)
15 changes: 11 additions & 4 deletions tests/test_io.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
from src.caked.dataloader import DiskDataLoader
import pytest
from __future__ import annotations

from caked.dataloader import DiskDataLoader


def test_class_instantiation():
test_loader = DiskDataLoader(pipeline="test", classes="test", dataset_size=3, save_to_disk=False, training=True)
assert isinstance(test_loader, DiskDataLoader)
test_loader = DiskDataLoader(
pipeline="test",
classes=["test"],
dataset_size=3,
save_to_disk=False,
training=True,
)
assert isinstance(test_loader, DiskDataLoader)

0 comments on commit 8439d04

Please sign in to comment.