diff --git a/external_confs/dm/acs.yaml b/external_confs/dm/acs.yaml
new file mode 100644
index 00000000..1c013720
--- /dev/null
+++ b/external_confs/dm/acs.yaml
@@ -0,0 +1,4 @@
+stratified_sampler: approx_group
+num_workers: 1
+batch_size_tr: 128
+batch_size_te: 100000
diff --git a/external_confs/ds/acs/employment_dis_fl.yaml b/external_confs/ds/acs/employment_dis_fl.yaml
new file mode 100644
index 00000000..3315fe45
--- /dev/null
+++ b/external_confs/ds/acs/employment_dis_fl.yaml
@@ -0,0 +1,10 @@
+---
+defaults:
+  - acs
+
+setting: employment_disability
+survey_year: YEAR_2018
+states:
+  - FL
+survey: PERSON
+horizon: ONE_YEAR
diff --git a/external_confs/experiment/acs/fcn.yaml b/external_confs/experiment/acs/fcn.yaml
new file mode 100644
index 00000000..5c84bc9d
--- /dev/null
+++ b/external_confs/experiment/acs/fcn.yaml
@@ -0,0 +1,38 @@
+# @package _global_
+
+defaults:
+  - /alg: supmatch_no_disc
+  - override /dm: acs
+  - override /ds: acs/employment_dis_fl
+  - override /split: acs/employment_dis
+  - override /labeller: gt
+  - override /ae_arch: fcn
+
+alg:
+  use_amp: False
+  pred:
+    lr: ${ ae.lr }
+  steps: 10000
+  val_freq: 1000
+  log_freq: ${ alg.steps }
+  # num_disc_updates: 3
+  # disc_loss_w: 0.03
+  # ga_steps: 1
+  # max_grad_norm: null
+
+ae:
+  recon_loss: l2
+  zs_dim: 1
+
+ae_opt:
+  lr: 1.e-4
+  optimizer_cls: ADAM
+  weight_decay: 0
+
+ae_arch:
+  hidden_dim: 64
+  latent_dim: 64
+  num_hidden: 2
+
+eval:
+  batch_size: 128
diff --git a/external_confs/split/acs/employment_dis.yaml b/external_confs/split/acs/employment_dis.yaml
new file mode 100644
index 00000000..ee8cdae9
--- /dev/null
+++ b/external_confs/split/acs/employment_dis.yaml
@@ -0,0 +1,8 @@
+---
+defaults:
+  - tabular
+
+seed: 0
+train_props:
+  1:
+    0: 0.0
diff --git a/poetry.lock b/poetry.lock
index 998ea5a7..6d1ca546 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -543,6 +543,7 @@ files = [
 
 [package.dependencies]
 filelock = "*"
+folktables = {version = ">=0.0.12", optional = true, markers = "extra == \"data\" or extra == \"all\""}
 jinja2 = "*"
 joblib = ">=1.1.0,<2.0.0"
 networkx = "*"
@@ -578,6 +579,23 @@ files = [
 docs = ["furo (>=2023.3.27)", "sphinx (>=6.1.3)", "sphinx-autodoc-typehints (>=1.23,!=1.23.4)"]
 testing = ["covdefaults (>=2.3)", "coverage (>=7.2.3)", "diff-cover (>=7.5)", "pytest (>=7.3.1)", "pytest-cov (>=4)", "pytest-mock (>=3.10)", "pytest-timeout (>=2.1)"]
 
+[[package]]
+name = "folktables"
+version = "0.0.12"
+description = "New machine learning benchmarks from tabular datasets."
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "folktables-0.0.12-py3-none-any.whl", hash = "sha256:979cda1900094b845ab3a8a3ae1b848f0138b780d5f8d17eeb6eb04c3c0c6617"},
+    {file = "folktables-0.0.12.tar.gz", hash = "sha256:e83dde0cbcdd54c7c39b175006a50bdfc4adc351f69d4389f82aaba3eee02115"},
+]
+
+[package.dependencies]
+numpy = "*"
+pandas = "*"
+requests = "*"
+scikit-learn = "*"
+
 [[package]]
 name = "fonttools"
 version = "4.39.4"
@@ -2855,18 +2873,20 @@ opt-einsum = ["opt-einsum (>=3.3)"]
 
 [[package]]
 name = "torch-conduit"
-version = "0.4.1"
+version = "0.4.2"
 description = "Lightweight framework for dataloading with PyTorch and channeling the power of PyTorch Lightning"
 optional = false
 python-versions = ">=3.10,<3.13"
 files = [
-    {file = "torch_conduit-0.4.1-py3-none-any.whl", hash = "sha256:d9e61232a7a017fd1ed7a9a0e080fb321b10ea816608518523b9c8ae9c9db44a"},
-    {file = "torch_conduit-0.4.1.tar.gz", hash = "sha256:4201f46355f7397c6e0bc1b39442b83daf207a9427ed93eebf0c1cbf59f67bd2"},
+    {file = "torch_conduit-0.4.2-py3-none-any.whl", hash = "sha256:ff8587f60d4fc79e298e0d883364a33710c1861ba8de2a5b4a6b17ddc3c46054"},
+    {file = "torch_conduit-0.4.2.tar.gz", hash = "sha256:9e4b3091f4f276b829a51a81639fa27ecaf7373958adb0f8306f3f01e36fc076"},
 ]
 
 [package.dependencies]
 albumentations = {version = ">=1.0.0,<2.0.0", optional = true, markers = "extra == \"image\" or extra == \"all\""}
 attrs = ">=21.2.0"
+ethicml = {version = ">=1.2.1,<2.0.0", extras = ["data"], optional = true, markers = "extra == \"fair\" or extra == \"all\""}
+folktables = {version = ">=0.0.12,<0.0.13", optional = true, markers = "extra == \"fair\" or extra == \"all\""}
 numpy = ">=1.22.3,<2.0.0"
 opencv-python = {version = ">=4.5.3,<5.0.0", optional = true, markers = "extra == \"image\" or extra == \"all\""}
 pandas = ">=1.3.3,<3.0"
@@ -3011,13 +3031,13 @@ files = [
 
 [[package]]
 name = "typing-extensions"
-version = "4.5.0"
-description = "Backported and Experimental Type Hints for Python 3.7+"
+version = "4.10.0"
+description = "Backported and Experimental Type Hints for Python 3.8+"
 optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
 files = [
-    {file = "typing_extensions-4.5.0-py3-none-any.whl", hash = "sha256:fb33085c39dd998ac16d1431ebc293a8b3eedd00fd4a32de0ff79002c19511b4"},
-    {file = "typing_extensions-4.5.0.tar.gz", hash = "sha256:5cb5f4a79139d699607b3ef622a1dedafa84e115ab0024e0d9c044a9479ca7cb"},
+    {file = "typing_extensions-4.10.0-py3-none-any.whl", hash = "sha256:69b1a937c3a517342112fb4c6df7e72fc39a38e7891a5730ed4985b5214b5475"},
+    {file = "typing_extensions-4.10.0.tar.gz", hash = "sha256:b0abd7c89e8fb96f98db18d86106ff1d90ab692004eb746cf6eda2682f91b3cb"},
 ]
 
 [[package]]
@@ -3238,4 +3258,4 @@ multidict = ">=4.0"
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.10,<3.13"
-content-hash = "0db240d14519c4a8a2078b91e59ce42803c85dbceb5d91b719ba7f6836c875f5"
+content-hash = "8e3dc2a8b7a8b97ff2ee17fec3b445c7ec645340c40c5f39b7a1cf552b54e5f4"
diff --git a/pyproject.toml b/pyproject.toml
index 0a01cebe..4bf6849e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,7 +30,8 @@ scikit-image = ">=0.14"
 scikit_learn = { version = ">=0.20.1" }
 scipy = { version = ">=1.2.1" }
 seaborn = { version = ">=0.9.0" }
-torch-conduit = { version = ">=0.3.4", extras = ["image"] }
+torch-conduit = { version = ">=0.4.2", extras = ["image", "fair"] }
+typing_extensions = ">= 4.10"
 tqdm = { version = ">=4.31.1" }
 typer = "*"
 
diff --git a/src/data/__init__.py b/src/data/__init__.py
index f2124c5f..5ad2129a 100644
--- a/src/data/__init__.py
+++ b/src/data/__init__.py
@@ -1,6 +1,6 @@
 from .common import *
 from .data_module import *
-from .nico_plus_plus import *
+from .factories import *
 from .nih import *
 from .splitter import *
 from .utils import *
diff --git a/src/data/factories.py b/src/data/factories.py
new file mode 100644
index 00000000..170bd1ff
--- /dev/null
+++ b/src/data/factories.py
@@ -0,0 +1,50 @@
+"""Dataset factories."""
+
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any, Optional, Union
+from typing_extensions import override
+
+from conduit.data.datasets.vision import NICOPP, NicoPPTarget
+from conduit.fair.data.datasets import (
+    ACSDataset,
+    ACSHorizon,
+    ACSSetting,
+    ACSState,
+    ACSSurvey,
+    ACSSurveyYear,
+)
+
+from src.data.common import DatasetFactory
+
+__all__ = ["ACSCfg", "NICOPPCfg"]
+
+
+@dataclass
+class NICOPPCfg(DatasetFactory):
+    root: Union[Path, str]
+    target_attrs: Optional[list[NicoPPTarget]] = None
+    transform: Any = None  # Optional[Union[Compose, BasicTransform, Callable[[Image], Any]]]
+
+    @override
+    def __call__(self) -> NICOPP:
+        return NICOPP(root=self.root, transform=self.transform, superclasses=self.target_attrs)
+
+
+@dataclass
+class ACSCfg(DatasetFactory):
+    setting: ACSSetting
+    survey_year: ACSSurveyYear = ACSSurveyYear.YEAR_2018
+    horizon: ACSHorizon = ACSHorizon.ONE_YEAR
+    survey: ACSSurvey = ACSSurvey.PERSON
+    states: list[ACSState] = field(default_factory=lambda: [ACSState.AL])
+
+    @override
+    def __call__(self) -> ACSDataset:
+        return ACSDataset(
+            setting=self.setting,
+            survey_year=self.survey_year,
+            horizon=self.horizon,
+            survey=self.survey,
+            states=self.states,
+        )
diff --git a/src/data/nico_plus_plus.py b/src/data/nico_plus_plus.py
deleted file mode 100644
index d3673f46..00000000
--- a/src/data/nico_plus_plus.py
+++ /dev/null
@@ -1,24 +0,0 @@
-"""NICO Dataset."""
-from dataclasses import dataclass
-from pathlib import Path
-from typing import Any, Optional, Union
-from typing_extensions import override
-
-from conduit.data.datasets.vision import CdtVisionDataset, NICOPP, NicoPPTarget
-from conduit.data.structures import TernarySample
-from torch import Tensor
-
-from src.data.common import DatasetFactory
-
-__all__ = ["NICOPPCfg"]
-
-
-@dataclass
-class NICOPPCfg(DatasetFactory):
-    root: Union[Path, str]
-    target_attrs: Optional[list[NicoPPTarget]] = None
-    transform: Any = None  # Optional[Union[Compose, BasicTransform, Callable[[Image], Any]]]
-
-    @override
-    def __call__(self) -> CdtVisionDataset[TernarySample, Tensor, Tensor]:
-        return NICOPP(root=self.root, transform=self.transform, superclasses=self.target_attrs)
diff --git a/src/data/splitter.py b/src/data/splitter.py
index 48c22a18..e9634a0b 100644
--- a/src/data/splitter.py
+++ b/src/data/splitter.py
@@ -1,5 +1,7 @@
-from abc import abstractmethod
+from abc import ABC, abstractmethod
+from collections.abc import Callable
 from dataclasses import dataclass
+from enum import Enum
 from pathlib import Path
 import platform
 from tempfile import TemporaryDirectory
@@ -10,7 +12,10 @@
 from conduit.data.datasets import random_split
 from conduit.data.datasets.utils import stratified_split
 from conduit.data.datasets.vision import CdtVisionDataset, ImageTform, PillowTform
+from conduit.fair.data.datasets import ACSDataset
+from conduit.transforms import MinMaxNormalize, TabularNormalize, ZScoreNormalize
 from loguru import logger
+from ranzen import some
 import torch
 from torch import Tensor
 import torchvision.transforms as T
@@ -24,14 +29,22 @@
     "DataSplitter",
"RandomSplitter", "SplitFromArtifact", + "TabularSplitter", "load_split_inds_from_artifact", "save_split_inds_as_artifact", ] @dataclass(eq=False) -class DataSplitter: - """How to split the data into train/test/dep.""" +class DataSplitter(ABC): + @abstractmethod + def __call__(self, dataset: D) -> TrainDepTestSplit[D]: + """Split the dataset into train/deployment/test.""" + + +@dataclass(eq=False) +class _VisionDataSplitter(DataSplitter): + """Common methods for transforming vision datasets.""" transductive: bool = False """Whether to include the test data in the pool of unlabelled data.""" @@ -133,7 +146,7 @@ def save_split_inds_as_artifact( @dataclass(eq=False) -class RandomSplitter(DataSplitter): +class RandomSplitter(_VisionDataSplitter): seed: int = 42 dep_prop: float = 0.4 test_prop: float = 0.2 @@ -259,7 +272,7 @@ def load_split_inds_from_artifact( @dataclass(eq=False, kw_only=True) -class SplitFromArtifact(DataSplitter): +class SplitFromArtifact(_VisionDataSplitter): artifact_name: str version: Optional[int] = None @@ -272,3 +285,41 @@ def split(self, dataset: D) -> TrainDepTestSplit[D]: dep_data = dataset.subset(splits["dep"]) test_data = dataset.subset(splits["test"]) return TrainDepTestSplit(train=train_data, deployment=dep_data, test=test_data) + + +class TabularTform(Enum): + zscore_normalize = (ZScoreNormalize,) + minmax_normalize = (MinMaxNormalize,) + + def __init__(self, tform: Callable[[], TabularNormalize]) -> None: + self.tf = tform + + +@dataclass(eq=False) +class TabularSplitter(DataSplitter): + """Split and transform tabular datasets.""" + + seed: int + train_props: dict[int, dict[int, float]] | None = None + dep_prop: float = 0.2 + test_prop: float = 0.1 + transform: TabularTform | None = TabularTform.zscore_normalize + + @override + def __call__(self, dataset: D) -> TrainDepTestSplit[D]: + if not isinstance(dataset, ACSDataset): + raise NotImplementedError("TabularSplitter only supports splitting of `ACSDataset`.") + + train, dep, test = dataset.subsampled_split( + train_props=self.train_props, + val_prop=self.dep_prop, + test_prop=self.test_prop, + seed=self.seed, + ) + if some(tf_type := self.transform): + tf = tf_type.tf() + train.fit_transform_(tf) + dep.transform_(tf) + test.transform_(tf) + + return TrainDepTestSplit(train=train, deployment=dep, test=test) diff --git a/src/relay/base.py b/src/relay/base.py index 4f9ccbe8..bb8bafb2 100644 --- a/src/relay/base.py +++ b/src/relay/base.py @@ -8,7 +8,7 @@ from src.data import DataModule, DataModuleConf, RandomSplitter, SplitFromArtifact from src.data.common import Dataset -from src.data.splitter import DataSplitter +from src.data.splitter import DataSplitter, TabularSplitter from src.labelling import Labeller from src.logging import WandbConf @@ -23,7 +23,11 @@ class BaseRelay: seed: int = 0 options: ClassVar[dict[str, dict[str, type]]] = { - "split": {"random": RandomSplitter, "artifact": SplitFromArtifact} + "split": { + "random": RandomSplitter, + "artifact": SplitFromArtifact, + "tabular": TabularSplitter, + } } def init_dm( diff --git a/src/relay/fs.py b/src/relay/fs.py index 2c41a6a2..ebbbb127 100644 --- a/src/relay/fs.py +++ b/src/relay/fs.py @@ -8,6 +8,7 @@ from src.arch.backbones import DenseNet, LinearResNet, ResNet, SimpleCNN from src.arch.predictors.fcn import Fcn from src.data import DatasetFactory, NICOPPCfg, NIHChestXRayDatasetCfg +from src.data.factories import ACSCfg from src.hydra_confs.datasets import Camelyon17Cfg, CelebACfg, ColoredMNISTCfg from src.labelling.pipeline import ( 
     CentroidalLabelNoiser,
@@ -44,6 +45,7 @@ class FsRelay(BaseRelay):
 
     options: ClassVar[dict[str, dict[str, type]]] = BaseRelay.options | {
         "ds": {
+            "acs": ACSCfg,
             "cmnist": ColoredMNISTCfg,
             "celeba": CelebACfg,
             "camelyon17": Camelyon17Cfg,
diff --git a/src/relay/label.py b/src/relay/label.py
index 0fed1397..8e745997 100644
--- a/src/relay/label.py
+++ b/src/relay/label.py
@@ -3,6 +3,7 @@
 from attrs import define, field
 
 from src.data.common import DatasetFactory
+from src.data.factories import ACSCfg
 from src.data.nih import NIHChestXRayDatasetCfg
 from src.data.utils import resolve_device
 from src.hydra_confs.datasets import Camelyon17Cfg, CelebACfg, ColoredMNISTCfg
@@ -31,6 +32,7 @@ class LabelRelay(BaseRelay):
 
     options: ClassVar[dict[str, dict[str, type]]] = BaseRelay.options | {
         "ds": {
+            "acs": ACSCfg,
             "cmnist": ColoredMNISTCfg,
             "celeba": CelebACfg,
             "camelyon17": Camelyon17Cfg,
diff --git a/src/relay/split.py b/src/relay/split.py
index d73363e7..7c888110 100644
--- a/src/relay/split.py
+++ b/src/relay/split.py
@@ -4,7 +4,7 @@
 
 from src.data import RandomSplitter
 from src.data.common import DatasetFactory
-from src.data.nico_plus_plus import NICOPPCfg
+from src.data.factories import NICOPPCfg
 from src.data.nih import NIHChestXRayDatasetCfg
 from src.hydra_confs.datasets import Camelyon17Cfg, CelebACfg
 from src.logging import WandbConf
diff --git a/src/relay/supmatch.py b/src/relay/supmatch.py
index e06e6b05..c5b4dc68 100644
--- a/src/relay/supmatch.py
+++ b/src/relay/supmatch.py
@@ -21,7 +21,7 @@
 from src.arch.predictors.base import PredictorFactory
 from src.arch.predictors.fcn import Fcn, SetFcn
 from src.data.common import DatasetFactory
-from src.data.nico_plus_plus import NICOPPCfg
+from src.data.factories import ACSCfg, NICOPPCfg
 from src.data.nih import NIHChestXRayDatasetCfg
 from src.hydra_confs.datasets import Camelyon17Cfg, CelebACfg, ColoredMNISTCfg
 from src.labelling.pipeline import (
@@ -71,6 +71,7 @@ class SupMatchRelay(BaseRelay):
     options: ClassVar[dict[str, dict[str, type]]] = BaseRelay.options | {
         "scorer": {"neural": NeuralScorer, "none": NullScorer},
         "ds": {
+            "acs": ACSCfg,
            "cmnist": ColoredMNISTCfg,
            "celeba": CelebACfg,
            "camelyon17": Camelyon17Cfg,
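
Reviewer note: the new pieces above compose end-to-end as sketched below. This is a hypothetical usage sketch, not code from the diff; it assumes the conduit.fair ACS API exactly as src/data/factories.py imports it, including an ACSSetting.employment_disability member matching the setting: key in employment_dis_fl.yaml.

# Hypothetical sketch: build the FL employment/disability dataset via the
# new ACSCfg factory, then split and normalise it with TabularSplitter.
from conduit.fair.data.datasets import ACSSetting, ACSState

from src.data.factories import ACSCfg
from src.data.splitter import TabularSplitter, TabularTform

# Mirrors external_confs/ds/acs/employment_dis_fl.yaml (survey_year, survey,
# and horizon keep their YEAR_2018 / PERSON / ONE_YEAR defaults).
dataset = ACSCfg(
    setting=ACSSetting.employment_disability,  # assumed enum member
    states=[ACSState.FL],
)()

# Mirrors external_confs/split/acs/employment_dis.yaml: seed 0, default
# dep/test proportions; z-score statistics are fit on the train split only
# and then applied unchanged to the deployment and test splits.
split = TabularSplitter(seed=0, transform=TabularTform.zscore_normalize)(dataset)
print(len(split.train), len(split.deployment), len(split.test))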
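
On the train_props mapping consumed by TabularSplitter and set in employment_dis.yaml: the nested dict is forwarded verbatim to ACSDataset.subsampled_split, and the diff itself does not document the key convention. The reading below (outer key = class label, inner key = subgroup, value = proportion kept for training) is an assumption, not a spec.

# Assumed reading of external_confs/split/acs/employment_dis.yaml:
# keep 0% of the (label=1, subgroup=0) cell in the training split, so that
# combination appears only in the deployment and test data.
train_props = {1: {0: 0.0}}
splitter = TabularSplitter(seed=0, train_props=train_props)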