Big spring clean up
tmke8 committed Feb 23, 2024
1 parent e05fe71 commit 34824c5
Showing 22 changed files with 155 additions and 213 deletions.
43 changes: 42 additions & 1 deletion .github/workflows/ci.yml
@@ -7,7 +7,7 @@ on:
       - main

 jobs:
-  format_with_black:
+  format_with_ruff:

     runs-on: ubuntu-latest
@@ -40,3 +40,44 @@ jobs:
       - name: Lint with ruff
         run: |
           ruff check --output-format=github .
+
+  run_type_checking:
+    needs:
+      - format_with_black
+      - lint_with_ruff
+    runs-on: ubuntu-latest
+
+    steps:
+      # ----------------------------------------------
+      # ----  check-out repo and set-up python    ----
+      # ----------------------------------------------
+      - name: Check out repository
+        uses: actions/checkout@v3
+      # ----------------------------------------------
+      # -----  install & configure poetry        -----
+      # ----------------------------------------------
+      - name: Install poetry
+        run: pipx install poetry
+      - name: Set up Python 3.10
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.10'
+          cache: 'poetry'
+
+      # ----------------------------------------------
+      # install dependencies if cache does not exist
+      # ----------------------------------------------
+      - name: Install dependencies
+        run: |
+          poetry env use 3.10
+          poetry install --no-interaction --no-root --without torch
+      - name: Set python path for all subsequent actions
+        run: echo "$(poetry env info --path)/bin" >> $GITHUB_PATH
+
+      # ----------------------------------------------
+      # -----  install and run pyright           -----
+      # ----------------------------------------------
+      - uses: jakebailey/pyright-action@v1
+        with:
+          # don't show warnings
+          level: error
110 changes: 1 addition & 109 deletions poetry.lock

Some generated files are not rendered by default.

3 changes: 1 addition & 2 deletions pyproject.toml
@@ -54,8 +54,6 @@ torchvision = ">=0.15.2"

 [tool.poetry.group.dev.dependencies]
 ruff = "*"
-mypy = "*"
-pytest = "*"
 types-tqdm = "*"
 pandas-stubs = "*"
@@ -124,6 +122,7 @@ reportUnknownLambdaType = "none"
 reportUnknownVariableType = "none"
 reportUnknownMemberType = "none"
 reportMissingTypeArgument = "none"
+reportUnnecessaryCast = "warning"
 reportUnnecessaryTypeIgnoreComment = "warning"
 exclude = [
     "outputs",
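A note on the new pyright setting: reportUnnecessaryCast flags typing.cast calls whose target type the checker has already inferred for the value. A minimal sketch (hypothetical function, not from this repo) of code the new "warning" level would flag:

    from typing import cast

    def double(x: int) -> int:
        # flagged under reportUnnecessaryCast = "warning":
        # x is already statically known to be an int
        return cast(int, x) * 2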
25 changes: 11 additions & 14 deletions src/algs/adv/evaluator.py
@@ -5,7 +5,6 @@

 from conduit.data import TernarySample
 from conduit.data.datasets import CdtDataLoader, CdtDataset
-from conduit.data.datasets.vision import CdtVisionDataset
 from loguru import logger
 from matplotlib import pyplot as plt
 from matplotlib.colors import ListedColormap
@@ -22,14 +21,12 @@
 from src.arch.predictors import Fcn
 from src.data import DataModule, Dataset, group_id_to_label, labels_to_group_id, resolve_device
 from src.evaluation.metrics import EmEvalPair, compute_metrics
-from src.logging import log_images
 from src.models import Classifier, OptimizerCfg, SplitLatentAe

 __all__ = [
     "Evaluator",
     "InvariantDatasets",
     "encode_dataset",
-    "log_sample_images",
     "visualize_clusters",
 ]
@@ -51,17 +48,17 @@ class InvariantDatasets(Generic[DY, DS]):
     zy: DS


-def log_sample_images(
-    *,
-    data: CdtVisionDataset[TernarySample[Tensor], Tensor, Tensor],
-    dm: DataModule,
-    name: str,
-    step: int,
-    num_samples: int = 64,
-) -> None:
-    inds: list[int] = torch.randperm(len(data))[:num_samples].tolist()
-    images = data[inds]
-    log_images(images=images, dm=dm, name=f"Samples from {name}", prefix="eval", step=step)
+# def log_sample_images(
+#     *,
+#     data: CdtVisionDataset[TernarySample[Tensor], Tensor, Tensor],
+#     dm: DataModule,
+#     name: str,
+#     step: int,
+#     num_samples: int = 64,
+# ) -> None:
+#     inds: list[int] = torch.randperm(len(data))[:num_samples].tolist()
+#     images = data[inds]
+#     log_images(images=images, dm=dm, name=f"Samples from {name}", prefix="eval", step=step)


 InvariantAttr = Literal["zy", "zs", "both"]
6 changes: 3 additions & 3 deletions src/algs/adv/supmatch.py
@@ -33,10 +33,9 @@ def _get_data_iterators(self, dm: DataModule) -> tuple[IterTr, IterDep]:
         dl_tr = dm.train_dataloader(balance=True)
         # The batch size needs to be consistent for the aggregation layer in the setwise neural
         # discriminator
+        batch_size: int = dl_tr.batch_sampler.batch_size  # type: ignore
         dl_dep = dm.deployment_dataloader(
-            batch_size=dl_tr.batch_sampler.batch_size
-            if dm.deployment_ids is None
-            else dm.batch_size_tr
+            batch_size=batch_size if dm.deployment_ids is None else dm.batch_size_tr
         )
         return iter(dl_tr), iter(dl_dep)
@@ -161,6 +160,7 @@ def fit_evaluate_score(
         disc_model_sd0 = None
         if isinstance(disc, NeuralDiscriminator) and isinstance(disc.model, SetPredictor):
             disc_model_sd0 = disc.model.state_dict()
+        assert isinstance(disc, Model)
         super().fit_and_evaluate(dm=dm, ae=ae, disc=disc, evaluator=evaluator)
         # TODO: Generalise this to other discriminator types and architectures
         if disc_model_sd0 is not None:
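The added assert isinstance(disc, Model) looks like a narrowing aid for the type checker: after an isinstance assert, pyright treats the variable as the asserted type for the rest of the scope. A toy sketch of the pattern (stand-in classes, not the repo's):

    class Model: ...
    class NeuralDiscriminator(Model): ...

    def fit_and_evaluate(disc: Model) -> None: ...

    def fit(disc: object) -> None:
        assert isinstance(disc, Model)  # narrows `disc` from object to Model
        fit_and_evaluate(disc)  # now type-checks without a cast

    fit(NeuralDiscriminator())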
8 changes: 2 additions & 6 deletions src/algs/fs/gdro.py
@@ -232,18 +232,14 @@ def update_stats(
         self.avg_acc = group_frac @ self.avg_group_acc


-@dataclass
-class _LcMixin:
+@dataclass(kw_only=True, repr=False, eq=False, frozen=True)
+class GdroClassifier(Classifier):
     loss_computer: LossComputer

-
-@dataclass(repr=False, eq=False)
-class GdroClassifier(Classifier, _LcMixin):
     def __post_init__(self) -> None:
         # LossComputer requires that the criterion return per-sample (unreduced) losses.
         if self.criterion is not None:
             self.criterion.reduction = ReductionType.none
         super().__post_init__()

     @override
     def training_step(self, batch: TernarySample[Tensor], *, pred_s: bool = False) -> Tensor:
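The deleted _LcMixin existed to work around dataclass field ordering: a plain dataclass places inherited fields first, so a required subclass field like loss_computer could not follow inherited fields with defaults. With kw_only=True (Python 3.10+), the subclass fields become keyword-only and the ordering restriction disappears, letting the field live directly on GdroClassifier. A minimal sketch with toy fields (not the repo's):

    from dataclasses import dataclass

    @dataclass
    class Base:
        lr: float = 1e-3  # inherited field with a default

    # without kw_only=True this definition would raise:
    # TypeError: non-default argument 'loss_computer' follows default argument
    @dataclass(kw_only=True)
    class Sub(Base):
        loss_computer: str  # required field after a defaulted one is now legal

    s = Sub(loss_computer="LossComputer()")  # keyword-only by construction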
30 changes: 13 additions & 17 deletions src/algs/fs/lff.py
@@ -1,5 +1,6 @@
 from collections.abc import Iterator
-from dataclasses import dataclass, field
+from dataclasses import dataclass
+from functools import cached_property
 from typing import Any, TypeVar, Union
 from typing_extensions import Self, override
@@ -98,24 +99,20 @@ def __len__(self) -> int:
         return len(self.dataset)


-@dataclass
-class _LabelEmaMixin:
+@dataclass(kw_only=True, repr=False, eq=False, frozen=True)
+class LfFClassifier(Classifier):
+    criterion: CrossEntropyLoss
     sample_loss_ema_b: LabelEma
     sample_loss_ema_d: LabelEma

-
-@dataclass(repr=False, eq=False)
-class LfFClassifier(Classifier, _LabelEmaMixin):
     q: float = 0.7
-    biased_model: nn.Module = field(init=False)
-    biased_criterion: GeneralizedCELoss = field(init=False)
-    criterion: CrossEntropyLoss = field(init=False)

-    def __post_init__(self) -> None:
-        self.biased_model = gcopy(self.model, deep=True)
-        self.biased_criterion = GeneralizedCELoss(q=self.q, reduction="mean")
-        self.criterion = CrossEntropyLoss(reduction="mean")
-        super().__post_init__()
+    @cached_property
+    def biased_model(self) -> nn.Module:
+        return gcopy(self.model, deep=True)
+
+    @cached_property
+    def biased_criterion(self) -> GeneralizedCELoss:
+        return GeneralizedCELoss(q=self.q, reduction="mean")

     def training_step(self, batch: IndexedSample[Tensor], *, pred_s: bool = False) -> Tensor:  # type: ignore
         logit_b = self.biased_model(batch.x)
@@ -158,14 +155,13 @@ def routine(self, dm: DataModule, *, model: nn.Module) -> EvalTuple[Tensor, None]:
         sample_loss_ema_d = LabelEma(dm.train.y, alpha=self.alpha).to(self.device)
         dm.train = IndexedDataset(dm.train)  # type: ignore
         classifier = LfFClassifier(
+            criterion=CrossEntropyLoss(reduction="mean"),
+            sample_loss_ema_b=sample_loss_ema_b,
+            sample_loss_ema_d=sample_loss_ema_d,
             model=model,
             opt=self.opt,
             q=self.q,
         )
-        classifier.sample_loss_ema_b = sample_loss_ema_b
-        classifier.sample_loss_ema_d = sample_loss_ema_d
         classifier.fit(
             train_data=dm.train_dataloader(),
             test_data=dm.test_dataloader(),
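Because the dataclass is now frozen, the old __post_init__ assignments would raise FrozenInstanceError; the commit instead derives biased_model and biased_criterion lazily via cached_property, which writes to the instance __dict__ directly and so sidesteps the frozen __setattr__. A small sketch of the pattern with stand-in types:

    from dataclasses import dataclass
    from functools import cached_property

    @dataclass(kw_only=True, frozen=True)
    class LazyClassifier:
        q: float = 0.7

        @cached_property
        def biased_criterion(self) -> dict:
            # computed on first access, then cached on the instance
            return {"loss": "GeneralizedCE", "q": self.q}

    clf = LazyClassifier()
    assert clf.biased_criterion is clf.biased_criterion  # cached: same object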
4 changes: 2 additions & 2 deletions src/arch/autoencoder/vqgan.py
@@ -195,7 +195,7 @@ def __init__(

         # end
         self.norm_out = Normalize(block_in)
-        flattened_size = np.prod((block_in, curr_res, curr_res))
+        flattened_size = np.prod((block_in, curr_res, curr_res)).item()
         self.to_latent = nn.Sequential(
             nn.Flatten(),
             nn.Linear(flattened_size, out_features=latent_dim),
@@ -253,7 +253,7 @@ def __init__(
         curr_res = resolution // 2 ** (self.num_resolutions - 1)
         unflattened_size = (block_in, curr_res, curr_res)
         self.from_latent = nn.Sequential(
-            nn.Linear(latent_dim, np.prod(unflattened_size)),
+            nn.Linear(latent_dim, np.prod(unflattened_size).item()),
             nn.Unflatten(dim=1, unflattened_size=unflattened_size),
         )
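For context on the .item() additions: np.prod returns a NumPy integer scalar rather than a built-in int, and nn.Linear's parameters are annotated as plain int, so pyright rejects the unconverted value; .item() performs the conversion. Illustration:

    import numpy as np

    flattened_size = np.prod((16, 4, 4))
    print(type(flattened_size))         # a NumPy integer scalar, not int
    print(type(flattened_size.item()))  # <class 'int'>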
2 changes: 1 addition & 1 deletion src/arch/backbones/vision.py
@@ -162,7 +162,7 @@ def __call__(self, input_dim: int) -> BackboneFactoryOut["tm.SwinTransformer"]:
         model: "tm.SwinTransformer" = timm.create_model(
             self.version.value, pretrained=self.pretrained, checkpoint_path=self.checkpoint_path
         )
-        model.head = nn.Identity()
+        model.head = nn.Identity()  # type: ignore
         return model, model.num_features

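Replacing the classification head with nn.Identity is the usual way to turn a timm model into a feature extractor (forward then returns the pooled features); the # type: ignore silences pyright, which expects the attribute's declared module type. A hedged sketch using a small timm model (resnet18, whose head attribute is named fc):

    import timm
    import torch
    import torch.nn as nn

    model = timm.create_model("resnet18", pretrained=False)
    model.fc = nn.Identity()  # ResNet's head is `fc`; SwinTransformer uses `head`
    features = model(torch.randn(1, 3, 224, 224))
    print(features.shape, model.num_features)  # torch.Size([1, 512]) 512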
(Diffs for the remaining 13 changed files are not rendered here.)
