Skip to content

Commit

Permalink
Add ACS dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
tmke8 committed Feb 27, 2024
1 parent 2efd6fa commit 12ca592
Show file tree
Hide file tree
Showing 15 changed files with 177 additions and 50 deletions.
10 changes: 10 additions & 0 deletions external_confs/ds/acs/employment_dis_fl.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
---
defaults:
- acs

setting: employment_disability
survey_year: YEAR_2018
states:
- FL
survey: PERSON
horizon: ONE_YEAR
8 changes: 8 additions & 0 deletions external_confs/split/acs/employment_dis.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
---
defaults:
- tabular

seed: 0
train_props:
  1:
    0: 0.0
34 changes: 27 additions & 7 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 6 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ scikit-image = ">=0.14"
scikit_learn = { version = ">=0.20.1" }
scipy = { version = ">=1.2.1" }
seaborn = { version = ">=0.9.0" }
torch-conduit = { version = ">=0.3.4", extras = ["image"] }
torch-conduit = { version = ">=0.3.4", extras = ["image", "fair"] }
typing_extensions = ">= 4.10"

tqdm = { version = ">=4.31.1" }
typer = "*"
Expand All @@ -56,7 +57,7 @@ torchvision = ">=0.15.2"
ruff = "*"
types-tqdm = "*"
pandas-stubs = "*"
python-type-stubs = {git = "https://github.com/wearepal/python-type-stubs.git", rev = "8d5f608"}
python-type-stubs = { git = "https://github.com/wearepal/python-type-stubs.git", rev = "8d5f608" }

[build-system]
requires = ["poetry-core>=1.0.0"]
Expand All @@ -67,6 +68,9 @@ target-version = "py310"
line-length = 100
extend-exclude = ["hydra_plugins"]

[tool.ruff.format]
quote-style = "preserve"

[tool.ruff.lint]
select = ["I", "F", "E", "W", "UP"]
ignore = [
Expand Down
10 changes: 5 additions & 5 deletions src/arch/autoencoder/vqgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(self, in_channels: int, *, with_conv: bool) -> None:
else:
self.conv = None

def forward(self, x: Tensor) -> Tensor: # type: ignore
def forward(self, x: Tensor) -> Tensor:
if self.conv is not None:
pad = (0, 1, 0, 1)
x = F.pad(x, pad, mode="constant", value=0)
Expand Down Expand Up @@ -81,7 +81,7 @@ def __init__(
in_channels, out_channels, kernel_size=1, stride=1, padding=0
)

def forward(self, x: Tensor) -> Tensor: # type: ignore
def forward(self, x: Tensor) -> Tensor:
h = x
h = self.norm1(h)
h = F.silu(h)
Expand Down Expand Up @@ -112,7 +112,7 @@ def __init__(self, in_channels: int):
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

def forward(self, x: Tensor) -> Tensor: # type: ignore
def forward(self, x: Tensor) -> Tensor:
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
Expand Down Expand Up @@ -201,7 +201,7 @@ def __init__(
nn.Linear(flattened_size, out_features=latent_dim),
)

def forward(self, x: Tensor) -> Tensor: # type: ignore
def forward(self, x: Tensor) -> Tensor:
# timestep embedding
# downsampling
hs = [self.conv_in(x)]
Expand Down Expand Up @@ -288,7 +288,7 @@ def __init__(
self.norm_out = Normalize(block_in)
self.conv_out = nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)

def forward(self, z: Tensor) -> Tensor: # type: ignore
def forward(self, z: Tensor) -> Tensor:
# z to block_in
h = self.from_latent(z)

Expand Down
2 changes: 1 addition & 1 deletion src/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .common import *
from .data_module import *
from .nico_plus_plus import *
from .factories import *
from .nih import *
from .splitter import *
from .utils import *
3 changes: 1 addition & 2 deletions src/data/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from conduit.data import LoadedData, TernarySample, UnloadedData
from conduit.data.datasets import CdtDataset
from hydra.utils import to_absolute_path
from numpy import typing as npt
from torch import Tensor

__all__ = [
Expand Down Expand Up @@ -78,5 +77,5 @@ def num_samples_te(self) -> int:

class DatasetFactory(ABC):
@abstractmethod
def __call__(self) -> Dataset[npt.NDArray]:
def __call__(self) -> Dataset:
raise NotImplementedError()
50 changes: 50 additions & 0 deletions src/data/factories.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""Dataset factories."""

from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Optional, Union
from typing_extensions import override

from conduit.data.datasets.vision import NICOPP, NicoPPTarget
from conduit.fair.data.datasets import (
ACSDataset,
ACSHorizon,
ACSSetting,
ACSState,
ACSSurvey,
ACSSurveyYear,
)

from src.data.common import DatasetFactory

__all__ = ["NICOPPCfg"]


@dataclass
class NICOPPCfg(DatasetFactory):
    """Factory config that builds a NICO++ vision dataset on call."""

    root: Union[Path, str]
    target_attrs: Optional[list[NicoPPTarget]] = None
    # Kept as `Any` because the accepted transform types are heterogeneous:
    # Optional[Union[Compose, BasicTransform, Callable[[Image], Any]]]
    transform: Any = None

    @override
    def __call__(self) -> NICOPP:
        """Instantiate the configured NICO++ dataset."""
        dataset = NICOPP(
            root=self.root,
            transform=self.transform,
            superclasses=self.target_attrs,
        )
        return dataset


@dataclass
class ACSCfg(DatasetFactory):
    """Factory config that builds an American Community Survey (ACS) dataset on call."""

    setting: ACSSetting
    survey_year: ACSSurveyYear = ACSSurveyYear.YEAR_2018
    horizon: ACSHorizon = ACSHorizon.ONE_YEAR
    survey: ACSSurvey = ACSSurvey.PERSON
    # Default to a single state (Alabama); mutable default requires a factory.
    states: list[ACSState] = field(default_factory=lambda: [ACSState.AL])

    @override
    def __call__(self) -> ACSDataset:
        """Instantiate the configured ACS dataset."""
        kwargs = {
            "setting": self.setting,
            "survey_year": self.survey_year,
            "horizon": self.horizon,
            "survey": self.survey,
            "states": self.states,
        }
        return ACSDataset(**kwargs)
24 changes: 0 additions & 24 deletions src/data/nico_plus_plus.py

This file was deleted.

61 changes: 56 additions & 5 deletions src/data/splitter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from abc import abstractmethod
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
import platform
from tempfile import TemporaryDirectory
Expand All @@ -10,7 +12,10 @@
from conduit.data.datasets import random_split
from conduit.data.datasets.utils import stratified_split
from conduit.data.datasets.vision import CdtVisionDataset, ImageTform, PillowTform
from conduit.fair.data.datasets import ACSDataset
from conduit.transforms import MinMaxNormalize, TabularNormalize, ZScoreNormalize
from loguru import logger
from ranzen import some
import torch
from torch import Tensor
import torchvision.transforms as T
Expand All @@ -24,14 +29,22 @@
"DataSplitter",
"RandomSplitter",
"SplitFromArtifact",
"TabularSplitter",
"load_split_inds_from_artifact",
"save_split_inds_as_artifact",
]


@dataclass(eq=False)
class DataSplitter:
"""How to split the data into train/test/dep."""
class DataSplitter(ABC):
    """Abstract interface for splitting a dataset into train/deployment/test."""

    @abstractmethod
    def __call__(self, dataset: D) -> TrainDepTestSplit[D]:
        """Split the dataset into train/deployment/test."""


@dataclass(eq=False)
class _VisionDataSplitter(DataSplitter):
"""Common methods for transforming vision datasets."""

transductive: bool = False
"""Whether to include the test data in the pool of unlabelled data."""
Expand Down Expand Up @@ -133,7 +146,7 @@ def save_split_inds_as_artifact(


@dataclass(eq=False)
class RandomSplitter(DataSplitter):
class RandomSplitter(_VisionDataSplitter):
seed: int = 42
dep_prop: float = 0.4
test_prop: float = 0.2
Expand Down Expand Up @@ -259,7 +272,7 @@ def load_split_inds_from_artifact(


@dataclass(eq=False, kw_only=True)
class SplitFromArtifact(DataSplitter):
class SplitFromArtifact(_VisionDataSplitter):
artifact_name: str
version: Optional[int] = None

Expand All @@ -272,3 +285,41 @@ def split(self, dataset: D) -> TrainDepTestSplit[D]:
dep_data = dataset.subset(splits["dep"])
test_data = dataset.subset(splits["test"])
return TrainDepTestSplit(train=train_data, deployment=dep_data, test=test_data)


class TabularTform(Enum):
    """Config-selectable normalization transforms for tabular data.

    Each member's value is a 1-tuple wrapping the transform class: with
    tuple-valued members, ``Enum`` unpacks the tuple into ``__init__``, so
    the class itself arrives as ``tform`` and is stored on the member as
    ``tf`` — a zero-argument factory callers invoke as ``member.tf()``.
    """

    zscore_normalize = (ZScoreNormalize,)
    minmax_normalize = (MinMaxNormalize,)

    def __init__(self, tform: Callable[[], TabularNormalize]) -> None:
        # `tform` is the transform class unpacked from the member's 1-tuple.
        self.tf = tform


@dataclass(eq=False)
class TabularSplitter(DataSplitter):
    """Split and transform tabular datasets (currently only ``ACSDataset``).

    :param seed: RNG seed for the subsampled split.
    :param train_props: Optional nested subsampling proportions for the
        training set, forwarded to ``ACSDataset.subsampled_split``.
        # NOTE(review): keyed ``{outer: {inner: prop}}`` per the declared type;
        # exact group semantics come from conduit — confirm against its docs.
    :param dep_prop: Proportion of the data assigned to the deployment split.
    :param test_prop: Proportion of the data assigned to the test split.
    :param transform: Normalization fitted on train and applied to all splits;
        ``None`` disables normalization.
    """

    seed: int
    train_props: dict[int, dict[int, float]] | None = None
    dep_prop: float = 0.2
    test_prop: float = 0.1
    transform: TabularTform | None = TabularTform.zscore_normalize

    @override
    def __call__(self, dataset: D) -> TrainDepTestSplit[D]:
        if not isinstance(dataset, ACSDataset):
            raise NotImplementedError("TabularSplitter only supports splitting of `ACSDataset`.")

        # BUG FIX: `train_props` was previously hard-coded to None here,
        # silently ignoring the configured field (e.g. the value set in
        # external_confs/split/acs/employment_dis.yaml).
        train, dep, test = dataset.subsampled_split(
            train_props=self.train_props,
            val_prop=self.dep_prop,
            test_prop=self.test_prop,
            seed=self.seed,
        )
        if some(tf_type := self.transform):
            # Fit the normalizer on train only, then apply the fitted
            # statistics unchanged to the deployment and test splits.
            tf = tf_type.tf()
            train.fit_transform_(tf)
            dep.transform_(tf)
            test.transform_(tf)

        return TrainDepTestSplit(train=train, deployment=dep, test=test)
8 changes: 6 additions & 2 deletions src/relay/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from src.data import DataModule, DataModuleConf, RandomSplitter, SplitFromArtifact
from src.data.common import Dataset
from src.data.splitter import DataSplitter
from src.data.splitter import DataSplitter, TabularSplitter
from src.labelling import Labeller
from src.logging import WandbConf

Expand All @@ -23,7 +23,11 @@ class BaseRelay:
seed: int = 0

options: ClassVar[dict[str, dict[str, type]]] = {
"split": {"random": RandomSplitter, "artifact": SplitFromArtifact}
"split": {
"random": RandomSplitter,
"artifact": SplitFromArtifact,
"tabular": TabularSplitter,
}
}

def init_dm(
Expand Down
Loading

0 comments on commit 12ca592

Please sign in to comment.