From 8439d04eda8b68feb83c97f5186dfa37f733bbc1 Mon Sep 17 00:00:00 2001
From: mooniean
Date: Tue, 30 Jan 2024 11:11:02 +0000
Subject: [PATCH] lint fixed, pre-commit ran

---
 pyproject.toml          |  1 +
 src/caked/__init__.py   |  1 -
 src/caked/base.py       | 24 ++++++++++++------------
 src/caked/dataloader.py | 20 +++++++++++++++-----
 tests/test_io.py        | 15 +++++++++++----
 5 files changed, 39 insertions(+), 22 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 5e38036..a1b51cc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -34,6 +34,7 @@ dependencies = ["torch", "numpy"]
 test = [
   "pytest >=6",
   "pytest-cov >=3",
+  "pre-commit",
 ]
 dev = [
   "pytest >=6",
diff --git a/src/caked/__init__.py b/src/caked/__init__.py
index cc5ffaf..5b54632 100644
--- a/src/caked/__init__.py
+++ b/src/caked/__init__.py
@@ -3,6 +3,5 @@
 """
 from __future__ import annotations
-
 
 __all__ = ("__version__",)
 __version__ = "0.1.0"
diff --git a/src/caked/base.py b/src/caked/base.py
index a3a8d7f..d445b92 100644
--- a/src/caked/base.py
+++ b/src/caked/base.py
@@ -1,11 +1,17 @@
-from abc import ABC, abstractmethod
-from pathlib import Path
+from __future__ import annotations
 
-from torch.utils.data import DataLoader, Dataset
+from abc import ABC, abstractmethod
 
 
 class AbstractDataLoader(ABC):
-    def __init__(self, pipeline: str, classes: list[str], dataset_size: int, save_to_disk: bool, training: bool):
+    def __init__(
+        self,
+        pipeline: str,
+        classes: list[str],
+        dataset_size: int,
+        save_to_disk: bool,
+        training: bool,
+    ):
         self.pipeline = pipeline
         self.classes = classes
         self.dataset_size = dataset_size
@@ -26,16 +32,10 @@ def get_loader(self, split_size: float, batch_size: int):
 
 
 class AbstractDataset(ABC):
-    def __init__(self, origin: str, classes: Path) -> None:
-        pass
-
-    def __len__(self) -> int:
-        pass
-
     @abstractmethod
-    def set_len(self, length:int):
+    def set_len(self, length: int):
         pass
 
     @abstractmethod
-    def augment(self, augment:bool, aug_type:str):
+    def augment(self, augment: bool, aug_type: str):
         pass
diff --git a/src/caked/dataloader.py b/src/caked/dataloader.py
index b196d44..6aa1560 100644
--- a/src/caked/dataloader.py
+++ b/src/caked/dataloader.py
@@ -1,14 +1,24 @@
+from __future__ import annotations
+
 from .base import AbstractDataLoader
 
+
 class DiskDataLoader(AbstractDataLoader):
-    def __init__(self, pipeline: str, classes: list[str], dataset_size: int, save_to_disk: bool, training: bool) -> None:
+    def __init__(
+        self,
+        pipeline: str,
+        classes: list[str],
+        dataset_size: int,
+        save_to_disk: bool,
+        training: bool,
+    ) -> None:
         super().__init__(pipeline, classes, dataset_size, save_to_disk, training)
-    
+
     def load(self):
         return super().load()
-    
+
     def process(self):
         return super().process()
-    
+
     def get_loader(self, split_size: float, batch_size: int):
-        return super().get_loader(split_size, batch_size)
\ No newline at end of file
+        return super().get_loader(split_size, batch_size)
diff --git a/tests/test_io.py b/tests/test_io.py
index 395829a..edfa623 100644
--- a/tests/test_io.py
+++ b/tests/test_io.py
@@ -1,7 +1,14 @@
-from src.caked.dataloader import DiskDataLoader
-import pytest
+from __future__ import annotations
+
+from caked.dataloader import DiskDataLoader
 
 
 def test_class_instantiation():
-    test_loader = DiskDataLoader(pipeline="test", classes="test", dataset_size=3, save_to_disk=False, training=True)
-    assert isinstance(test_loader, DiskDataLoader)
\ No newline at end of file
+    test_loader = DiskDataLoader(
+        pipeline="test",
+        classes=["test"],
+        dataset_size=3,
+        save_to_disk=False,
+        training=True,
+    )
+    assert isinstance(test_loader, DiskDataLoader)