Skip to content

Commit

Permalink
Add ACS dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
tmke8 committed Feb 29, 2024
1 parent c9fb0cc commit 96c5d61
Show file tree
Hide file tree
Showing 15 changed files with 211 additions and 44 deletions.
4 changes: 4 additions & 0 deletions external_confs/dm/acs.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# DataModule overrides for the ACS (American Community Survey) dataset.
stratified_sampler: approx_group
num_workers: 1
# batch_size_tr / batch_size_te: training vs. evaluation batch sizes.
# NOTE(review): the huge eval batch presumably covers the whole tabular
# test split in a single pass — confirm against the dataset size.
batch_size_tr: 128
batch_size_te: 100000
10 changes: 10 additions & 0 deletions external_confs/ds/acs/employment_dis_fl.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
---
defaults:
- acs

setting: employment_disability
survey_year: YEAR_2018
states:
- FL
survey: PERSON
horizon: ONE_YEAR
38 changes: 38 additions & 0 deletions external_confs/experiment/acs/fcn.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# @package _global_

defaults:
- /alg: supmatch_no_disc
- override /dm: acs
- override /ds: acs/employment_dis_fl
- override /split: acs/employment_dis
- override /labeller: gt
- override /ae_arch: fcn

alg:
use_amp: False
pred:
lr: ${ ae.lr }
steps: 10000
val_freq: 1000
log_freq: ${ alg.steps }
# num_disc_updates: 3
# disc_loss_w: 0.03
# ga_steps: 1
# max_grad_norm: null

ae:
recon_loss: l2
zs_dim: 1

ae_opt:
lr: 1.e-4
optimizer_cls: ADAM
weight_decay: 0

ae_arch:
hidden_dim: 64
latent_dim: 64
num_hidden: 2

eval:
batch_size: 128
8 changes: 8 additions & 0 deletions external_confs/split/acs/employment_dis.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
---
defaults:
- tabular

seed: 0
train_props:
1:
0: 0.0
38 changes: 29 additions & 9 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ scikit-image = ">=0.14"
scikit_learn = { version = ">=0.20.1" }
scipy = { version = ">=1.2.1" }
seaborn = { version = ">=0.9.0" }
torch-conduit = { version = ">=0.3.4", extras = ["image"] }
torch-conduit = { version = ">=0.4.2", extras = ["image", "fair"] }
typing_extensions = ">= 4.10"

tqdm = { version = ">=4.31.1" }
typer = "*"
Expand Down
2 changes: 1 addition & 1 deletion src/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .common import *
from .data_module import *
from .nico_plus_plus import *
from .factories import *
from .nih import *
from .splitter import *
from .utils import *
50 changes: 50 additions & 0 deletions src/data/factories.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""Dataset factories."""

from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Optional, Union
from typing_extensions import override

from conduit.data.datasets.vision import NICOPP, NicoPPTarget
from conduit.fair.data.datasets import (
ACSDataset,
ACSHorizon,
ACSSetting,
ACSState,
ACSSurvey,
ACSSurveyYear,
)

from src.data.common import DatasetFactory

__all__ = ["NICOPPCfg"]


@dataclass
class NICOPPCfg(DatasetFactory):
    """Factory configuration for the NICO++ vision dataset."""

    root: Union[Path, str]
    target_attrs: Optional[list[NicoPPTarget]] = None
    transform: Any = None  # Optional[Union[Compose, BasicTransform, Callable[[Image], Any]]]

    @override
    def __call__(self) -> NICOPP:
        """Instantiate the ``NICOPP`` dataset described by this config."""
        dataset = NICOPP(
            root=self.root,
            transform=self.transform,
            superclasses=self.target_attrs,
        )
        return dataset


@dataclass
class ACSCfg(DatasetFactory):
    """Factory configuration for the ACS (American Community Survey) dataset."""

    setting: ACSSetting
    survey_year: ACSSurveyYear = ACSSurveyYear.YEAR_2018
    horizon: ACSHorizon = ACSHorizon.ONE_YEAR
    survey: ACSSurvey = ACSSurvey.PERSON
    states: list[ACSState] = field(default_factory=lambda: [ACSState.AL])

    @override
    def __call__(self) -> ACSDataset:
        """Instantiate the ``ACSDataset`` described by this config."""
        dataset_kwargs = {
            "setting": self.setting,
            "survey_year": self.survey_year,
            "horizon": self.horizon,
            "survey": self.survey,
            "states": self.states,
        }
        return ACSDataset(**dataset_kwargs)
24 changes: 0 additions & 24 deletions src/data/nico_plus_plus.py

This file was deleted.

61 changes: 56 additions & 5 deletions src/data/splitter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from abc import abstractmethod
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
import platform
from tempfile import TemporaryDirectory
Expand All @@ -10,7 +12,10 @@
from conduit.data.datasets import random_split
from conduit.data.datasets.utils import stratified_split
from conduit.data.datasets.vision import CdtVisionDataset, ImageTform, PillowTform
from conduit.fair.data.datasets import ACSDataset
from conduit.transforms import MinMaxNormalize, TabularNormalize, ZScoreNormalize
from loguru import logger
from ranzen import some
import torch
from torch import Tensor
import torchvision.transforms as T
Expand All @@ -24,14 +29,22 @@
"DataSplitter",
"RandomSplitter",
"SplitFromArtifact",
"TabularSplitter",
"load_split_inds_from_artifact",
"save_split_inds_as_artifact",
]


@dataclass(eq=False)
class DataSplitter:
"""How to split the data into train/test/dep."""
class DataSplitter(ABC):
@abstractmethod
def __call__(self, dataset: D) -> TrainDepTestSplit[D]:
"""Split the dataset into train/deployment/test."""


@dataclass(eq=False)
class _VisionDataSplitter(DataSplitter):
"""Common methods for transforming vision datasets."""

transductive: bool = False
"""Whether to include the test data in the pool of unlabelled data."""
Expand Down Expand Up @@ -133,7 +146,7 @@ def save_split_inds_as_artifact(


@dataclass(eq=False)
class RandomSplitter(DataSplitter):
class RandomSplitter(_VisionDataSplitter):
seed: int = 42
dep_prop: float = 0.4
test_prop: float = 0.2
Expand Down Expand Up @@ -259,7 +272,7 @@ def load_split_inds_from_artifact(


@dataclass(eq=False, kw_only=True)
class SplitFromArtifact(DataSplitter):
class SplitFromArtifact(_VisionDataSplitter):
artifact_name: str
version: Optional[int] = None

Expand All @@ -272,3 +285,41 @@ def split(self, dataset: D) -> TrainDepTestSplit[D]:
dep_data = dataset.subset(splits["dep"])
test_data = dataset.subset(splits["test"])
return TrainDepTestSplit(train=train_data, deployment=dep_data, test=test_data)


class TabularTform(Enum):
    """Tabular normalization transforms selectable by name (e.g. from Hydra).

    Each member's value is a 1-tuple wrapping a transform class. ``Enum``
    unpacks a tuple value into ``__init__``'s arguments, so ``__init__``
    receives the class itself and stores it on ``self.tf``; calling
    ``member.tf()`` then constructs a fresh transform instance. The trailing
    comma in each value is therefore essential — do not remove it.
    """

    zscore_normalize = (ZScoreNormalize,)
    minmax_normalize = (MinMaxNormalize,)

    def __init__(self, tform: Callable[[], TabularNormalize]) -> None:
        # `tform` is the transform *class* (a zero-arg callable), not an instance.
        self.tf = tform


@dataclass(eq=False)
class TabularSplitter(DataSplitter):
    """Split and transform tabular datasets.

    Delegates the splitting itself to ``ACSDataset.subsampled_split`` and then
    optionally fits a normalizing transform on the training split, applying the
    same fitted statistics to the deployment and test splits.
    """

    # RNG seed forwarded to `subsampled_split` for reproducible splits.
    seed: int
    # Per-group subsampling proportions, passed straight through to
    # `subsampled_split`. NOTE(review): presumably keyed by group/label
    # indices — confirm against conduit's `ACSDataset.subsampled_split`.
    train_props: dict[int, dict[int, float]] | None = None
    # Fractions of the data reserved for the deployment and test splits.
    dep_prop: float = 0.2
    test_prop: float = 0.1
    # Normalization applied after splitting; `None` disables it.
    transform: TabularTform | None = TabularTform.zscore_normalize

    @override
    def __call__(self, dataset: D) -> TrainDepTestSplit[D]:
        # Only `ACSDataset` exposes the `subsampled_split` API relied on below.
        if not isinstance(dataset, ACSDataset):
            raise NotImplementedError("TabularSplitter only supports splitting of `ACSDataset`.")

        # `dep_prop` is forwarded as `val_prop`: the deployment set plays the
        # role of conduit's validation split in this codebase.
        train, dep, test = dataset.subsampled_split(
            train_props=self.train_props,
            val_prop=self.dep_prop,
            test_prop=self.test_prop,
            seed=self.seed,
        )
        # Fit the normalizer on the training data only, then apply the fitted
        # statistics to dep/test — avoids leaking test statistics into training.
        # `some(...)` (ranzen) — presumably a not-None check; verify.
        if some(tf_type := self.transform):
            tf = tf_type.tf()
            train.fit_transform_(tf)
            dep.transform_(tf)
            test.transform_(tf)

        return TrainDepTestSplit(train=train, deployment=dep, test=test)
8 changes: 6 additions & 2 deletions src/relay/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from src.data import DataModule, DataModuleConf, RandomSplitter, SplitFromArtifact
from src.data.common import Dataset
from src.data.splitter import DataSplitter
from src.data.splitter import DataSplitter, TabularSplitter
from src.labelling import Labeller
from src.logging import WandbConf

Expand All @@ -23,7 +23,11 @@ class BaseRelay:
seed: int = 0

options: ClassVar[dict[str, dict[str, type]]] = {
"split": {"random": RandomSplitter, "artifact": SplitFromArtifact}
"split": {
"random": RandomSplitter,
"artifact": SplitFromArtifact,
"tabular": TabularSplitter,
}
}

def init_dm(
Expand Down
2 changes: 2 additions & 0 deletions src/relay/fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from src.arch.backbones import DenseNet, LinearResNet, ResNet, SimpleCNN
from src.arch.predictors.fcn import Fcn
from src.data import DatasetFactory, NICOPPCfg, NIHChestXRayDatasetCfg
from src.data.factories import ACSCfg
from src.hydra_confs.datasets import Camelyon17Cfg, CelebACfg, ColoredMNISTCfg
from src.labelling.pipeline import (
CentroidalLabelNoiser,
Expand Down Expand Up @@ -44,6 +45,7 @@ class FsRelay(BaseRelay):

options: ClassVar[dict[str, dict[str, type]]] = BaseRelay.options | {
"ds": {
"acs": ACSCfg,
"cmnist": ColoredMNISTCfg,
"celeba": CelebACfg,
"camelyon17": Camelyon17Cfg,
Expand Down
2 changes: 2 additions & 0 deletions src/relay/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from attrs import define, field

from src.data.common import DatasetFactory
from src.data.factories import ACSCfg
from src.data.nih import NIHChestXRayDatasetCfg
from src.data.utils import resolve_device
from src.hydra_confs.datasets import Camelyon17Cfg, CelebACfg, ColoredMNISTCfg
Expand Down Expand Up @@ -31,6 +32,7 @@ class LabelRelay(BaseRelay):

options: ClassVar[dict[str, dict[str, type]]] = BaseRelay.options | {
"ds": {
"acs": ACSCfg,
"cmnist": ColoredMNISTCfg,
"celeba": CelebACfg,
"camelyon17": Camelyon17Cfg,
Expand Down
Loading

0 comments on commit 96c5d61

Please sign in to comment.