diff --git a/src/beignet/datasets/_sabdab_dataset.py b/src/beignet/datasets/_sabdab_dataset.py index 3ae7488940..ff715e47d6 100644 --- a/src/beignet/datasets/_sabdab_dataset.py +++ b/src/beignet/datasets/_sabdab_dataset.py @@ -1,12 +1,12 @@ from pathlib import Path from typing import Callable -from torch.utils.data import Dataset - from beignet.transforms import Transform +from ._tdc_dataset import TDCDataset + -class SAbDabDataset(Dataset): +class SAbDabDataset(TDCDataset): def __init__( self, root: str | Path, @@ -32,10 +32,14 @@ def __init__( target_transform : Callable | Transform | None Transforms the target. """ - raise NotImplementedError - - def __getitem__(self, index: int): - raise NotImplementedError - - def __len__(self) -> int: - raise NotImplementedError + super().__init__( + root=root, + download=download, + identifier=4167357, + suffix="csv", + checksum="md5:f4d0dba68859f7ae2a042bd90423b22b", + x_keys=["X1", "X2"], + y_keys=["Y"], + transform=transform, + target_transform=target_transform, + )