Skip to content

Commit

Permalink
Automatically create schemas for yaml config used by Hydra (#7)
Browse files Browse the repository at this point in the history
* WIP: Automatically create yaml schemas for configs

Signed-off-by: Fabrice Normandin <[email protected]>

* Working (mostly)!

Signed-off-by: Fabrice Normandin <[email protected]>

* Add a .gitignore entry for schemas

Signed-off-by: Fabrice Normandin <[email protected]>

* Associate files to schemas via vscode settings

Signed-off-by: Fabrice Normandin <[email protected]>

* Improve descriptions

Signed-off-by: Fabrice Normandin <[email protected]>

* Update notes / todos

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix doctest

Signed-off-by: Fabrice Normandin <[email protected]>

* Save progress so far (lots more to do :()

Signed-off-by: Fabrice Normandin <[email protected]>

* Ignore most hydra errors for now

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix issues with logging verbosity

Signed-off-by: Fabrice Normandin <[email protected]>

* Rename functions and change signature

Signed-off-by: Fabrice Normandin <[email protected]>

* Make it possible to change only one file

Signed-off-by: Fabrice Normandin <[email protected]>

* Add "format on save" option to devcontainer

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix bug in experiment/example.yaml

Signed-off-by: Fabrice Normandin <[email protected]>

* Format a yaml file

Signed-off-by: Fabrice Normandin <[email protected]>

* Regenerate schema if config changes

Signed-off-by: Fabrice Normandin <[email protected]>

* Remove unused imports in example.py

Signed-off-by: Fabrice Normandin <[email protected]>

* Add nice schema for `defaults` list!

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix issue with unit test for auto schema

Signed-off-by: Fabrice Normandin <[email protected]>

* Add a 'features/auto_schema.md' file with a demo

Signed-off-by: Fabrice Normandin <[email protected]>

* Create schemas inside `main`, tweak defaults

Signed-off-by: Fabrice Normandin <[email protected]>

* Handle sub-entries with _target_

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix the errors in the test configs for auto_shema

Signed-off-by: Fabrice Normandin <[email protected]>

* Remove absolute path in regression files

Signed-off-by: Fabrice Normandin <[email protected]>

* Add tests to increase code coverage

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix issue when writing yaml schema header

Signed-off-by: Fabrice Normandin <[email protected]>

* Add test to cover changes in `main.py`

Signed-off-by: Fabrice Normandin <[email protected]>

* "fix" type-checking errors in auto_schema.py

Signed-off-by: Fabrice Normandin <[email protected]>

---------

Signed-off-by: Fabrice Normandin <[email protected]>
  • Loading branch information
lebrice authored Aug 26, 2024
1 parent eec1563 commit 2bb1a5d
Show file tree
Hide file tree
Showing 57 changed files with 2,196 additions and 371 deletions.
7 changes: 6 additions & 1 deletion .devcontainer/devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@
"editor.defaultFormatter": "charliermarsh.ruff",
"editor.formatOnType": true
},
"editor.rulers": [
99
],
"editor.formatOnSave": true,
"files.exclude": {
"**/.git": true,
"**/.svn": true,
Expand Down Expand Up @@ -61,7 +65,8 @@
"GitHub.copilot",
"knowsuchagency.pdm-task-provider",
"GitHub.copilot-chat",
"mutantdino.resourcemonitor"
"mutantdino.resourcemonitor",
"Gruntfuggly.todo-tree"
]
}
},
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,4 @@ lightning_logs
**.npz
# .python-version
.testmondata*
.schemas
2 changes: 2 additions & 0 deletions docs/SUMMARY.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
* overview/*.md
* Getting Started
* getting_started/*.md
* Features
* features/*.md
* Reference
* reference/*
* Examples
Expand Down
18 changes: 18 additions & 0 deletions docs/features/auto_schema.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Auto Schema for Hydra Configs

This project template comes with a really neat feature: Your [Hydra](https://hydra.cc) config files automatically get a [Schema](https://json-schema.org/) associated with them.

This greatly improves the experience of developing a project with Hydra:

- Saves you time by preventing errors caused by unexpected keys in your config files, or values that are of the wrong type
This can often happen after moving files or renaming a function, for example.
- While writing a config file you get to see:
- the list of available configuration options in a given config
- the default values for each value
- the documentation for each value (taken from the source code of the function!)

Here's a quick demo of what this looks like in practice:

![type:video](https://github.com/user-attachments/assets/08f52d47-ebba-456d-95ef-ac9525d8e983)

Here we have a config that will be used to configure the `lightning.Trainer` class, but any config file in the project will also get a schema automatically, even if it doesn't have a `"_target_"` key directly!
58 changes: 32 additions & 26 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,32 +14,38 @@ markdown_extensions:
- pymdownx.superfences

plugins:
- search
- literate-nav:
nav_file: SUMMARY.md
- awesome-pages
- gen-files:
# https://oprypin.github.io/mkdocs-gen-files/#usage
scripts:
- docs/generate_reference_docs.py
- mkdocstrings:
handlers:
python:
import:
- https://docs.python-requests.org/en/master/objects.inv
options:
docstring_style: google
members_order: source
annotations_path: brief
show_docstring_attributes: true
modernize_annotations: true
show_source: false
show_submodules: false
separate_signature: true
signature_crossrefs: true
show_signature_annotations: true
allow_inspection: true

- search
- literate-nav:
nav_file: SUMMARY.md
- awesome-pages
- gen-files:
# https://oprypin.github.io/mkdocs-gen-files/#usage
scripts:
- docs/generate_reference_docs.py
- mkdocstrings:
handlers:
python:
import:
- https://docs.python-requests.org/en/master/objects.inv
- https://lightning.ai/docs/pytorch/stable/objects.inv
options:
docstring_style: google
members_order: source
annotations_path: brief
show_docstring_attributes: true
modernize_annotations: true
show_source: false
show_submodules: false
separate_signature: true
signature_crossrefs: true
show_signature_annotations: true
allow_inspection: true
- mkdocs-video:
is_video: True
video_muted: True
video_controls: True
css_style:
width: "100%"
# todo: take a look at https://github.com/drivendataorg/cookiecutter-data-science/blob/master/docs/mkdocs.yml
# - admonition
# - pymdownx.details
Expand Down
2 changes: 1 addition & 1 deletion project/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .utils.hydra_utils import patched_safe_name # noqa

# from .networks import FcNet
from .utils.types import DataModule
from .utils.typing_utils import DataModule

add_configs_to_hydra_store()

Expand Down
2 changes: 1 addition & 1 deletion project/algorithms/callbacks/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from lightning import pytorch as pl
from typing_extensions import TypeVar, override

from project.utils.types import NestedMapping
from project.utils.typing_utils import NestedMapping
from project.utils.utils import get_log_dir

logger = get_logger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion project/algorithms/callbacks/classification_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing_extensions import NotRequired, Required, override

from project.algorithms.callbacks.callback import BatchType, Callback
from project.utils.types.protocols import ClassificationDataModule
from project.utils.typing_utils.protocols import ClassificationDataModule

logger = get_logger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion project/algorithms/callbacks/samples_per_second.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing_extensions import override

from project.algorithms.callbacks.callback import BatchType, Callback, StepOutputType
from project.utils.types import is_sequence_of
from project.utils.typing_utils import is_sequence_of


class MeasureSamplesPerSecondCallback(Callback[BatchType, StepOutputType]):
Expand Down
3 changes: 1 addition & 2 deletions project/algorithms/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,11 @@
from typing import Any, Literal

import torch
from hydra_zen.typing import HydraPartialBuilds, Partial, PartialBuilds, ZenPartialBuilds # noqa
from lightning import LightningModule
from omegaconf import DictConfig
from torch import Tensor
from torch.nn import functional as F
from torch.optim import Optimizer
from torch.optim.optimizer import Optimizer

from project.configs.algorithm.optimizer import AdamConfig
from project.datamodules.image_classification import ImageClassificationDataModule
Expand Down
2 changes: 1 addition & 1 deletion project/algorithms/jax_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
ImageClassificationDataModule,
)
from project.datamodules.image_classification.mnist import MNISTDataModule
from project.utils.types.protocols import ClassificationDataModule
from project.utils.typing_utils.protocols import ClassificationDataModule

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

Expand Down
2 changes: 1 addition & 1 deletion project/algorithms/no_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torch import nn

from project.algorithms.callbacks.samples_per_second import MeasureSamplesPerSecondCallback
from project.utils.types.protocols import DataModule
from project.utils.typing_utils.protocols import DataModule


class NoOp(LightningModule):
Expand Down
4 changes: 2 additions & 2 deletions project/algorithms/testsuites/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from torch import Tensor
from typing_extensions import NotRequired, TypeVar

from project.utils.types import PyTree
from project.utils.types.protocols import DataModule, Module
from project.utils.typing_utils import PyTree
from project.utils.typing_utils.protocols import DataModule, Module


class StepOutputDict(TypedDict, total=False):
Expand Down
4 changes: 2 additions & 2 deletions project/algorithms/testsuites/algorithm_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
ParametrizedFixture,
seeded_rng,
)
from project.utils.types import PyTree, is_sequence_of
from project.utils.types.protocols import DataModule
from project.utils.typing_utils import PyTree, is_sequence_of
from project.utils.typing_utils.protocols import DataModule

logger = get_logger(__name__)

Expand Down
1 change: 0 additions & 1 deletion project/configs/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,4 @@ defaults:
# experiment configs allow for version control of specific hyperparameters
# e.g. best hyperparameters for given model and datamodule
- experiment: null

# name: "${hydra:runtime.choices.algorithm}-${hydra:runtime.choices.network}-${hydra:runtime.choices.datamodule}"
3 changes: 2 additions & 1 deletion project/configs/datamodule/cifar10.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
defaults:
- vision
- vision
- _self_
_target_: project.datamodules.CIFAR10DataModule
data_dir: ${constant:torchvision_dir,DATA_DIR}
batch_size: 128
Expand Down
3 changes: 2 additions & 1 deletion project/configs/datamodule/fashion_mnist.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
defaults:
- mnist
- mnist
- _self_
_target_: project.datamodules.FashionMNISTDataModule
1 change: 1 addition & 0 deletions project/configs/datamodule/imagenet.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
defaults:
- vision
- _self_
_target_: project.datamodules.ImageNetDataModule
# todo: add good configuration options here.
3 changes: 2 additions & 1 deletion project/configs/datamodule/imagenet32.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
defaults:
- vision
- vision
- _self_
_target_: project.datamodules.ImageNet32DataModule
data_dir: ${constant:SCRATCH}
val_split: -1
Expand Down
3 changes: 2 additions & 1 deletion project/configs/datamodule/inaturalist.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
defaults:
- vision
- vision
- _self_
_target_: project.datamodules.INaturalistDataModule
version: "2021_train"
target_type: "full"
3 changes: 2 additions & 1 deletion project/configs/datamodule/mnist.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
defaults:
- vision
- vision
- _self_
_target_: project.datamodules.MNISTDataModule
data_dir: ${constant:torchvision_dir,DATA_DIR}
normalize: True
Expand Down
13 changes: 6 additions & 7 deletions project/configs/experiment/cluster_sweep_example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,13 @@ hydra:
subdir: ${hydra.job.id}/task${oc.env:SLURM_PROCID}

launcher:
array_parallelism: 4 # max num of jobs to run in parallel
array_parallelism: 4 # max num of jobs to run in parallel
# note: Can't currently use this argument with the submitit launcher plugin. We would need to
# create a subclass to add support for it.
# ntasks_per_gpu: 2
# Other things to pass to `sbatch`:
additional_parameters:
time: 0-00:10:00 # maximum wall time allocated for the job (D-HH:MM:SS)
time: 0-00:10:00 # maximum wall time allocated for the job (D-HH:MM:SS)
# requeue: null # requeue job if it fails
## A list of commands to add to the generated sbatch script before running srun:
# setup:
Expand All @@ -61,12 +61,11 @@ hydra:
name: "${name}"
version: 1


algorithm:
# BUG: Getting a weird bug with TPE: KeyError in `dum_below_trials = [...]` at line 397.
# BUG: Getting a weird bug with TPE: KeyError in `dum_below_trials = [...]` at line 397.
type: random
config:
seed: 1
seed: 1

worker:
n_workers: ${hydra.launcher.array_parallelism}
Expand All @@ -77,5 +76,5 @@ hydra:
type: legacy
use_hydra_path: false
database:
type: pickleddb
host: "logs/${name}/multiruns/database.pkl"
type: pickleddb
host: "logs/${name}/multiruns/database.pkl"
6 changes: 2 additions & 4 deletions project/configs/experiment/example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,12 @@ name: example
seed: ${oc.env:SLURM_PROCID,12345}

algorithm:
hp:
optimizer:
lr: 0.002
optimizer_config:
lr: 0.002

datamodule:
batch_size: 64


trainer:
min_epochs: 1
max_epochs: 10
Expand Down
1 change: 0 additions & 1 deletion project/configs/trainer/default.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
_target_: lightning.Trainer

logger: null
accelerator: auto
strategy: auto
Expand Down
4 changes: 2 additions & 2 deletions project/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@
default_marks_for_config_name,
seeded_rng,
)
from project.utils.types import is_sequence_of
from project.utils.types.protocols import (
from project.utils.typing_utils import is_sequence_of
from project.utils.typing_utils.protocols import (
DataModule,
)

Expand Down
2 changes: 1 addition & 1 deletion project/datamodules/datamodules_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
)
from project.datamodules.vision import VisionDataModule
from project.utils.testutils import run_for_all_datamodules
from project.utils.types import is_sequence_of
from project.utils.typing_utils import is_sequence_of


# @use_overrides(["datamodule.num_workers=0"])
Expand Down
2 changes: 1 addition & 1 deletion project/datamodules/image_classification/cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
ImageClassificationDataModule,
)
from project.datamodules.vision import VisionDataModule
from project.utils.types import C, H, W
from project.utils.typing_utils import C, H, W


def cifar10_train_transforms():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from torchvision.tv_tensors import Image

from project.datamodules.vision import VisionDataModule
from project.utils.types import C, H, W
from project.utils.types.protocols import ClassificationDataModule
from project.utils.typing_utils import C, H, W
from project.utils.typing_utils.protocols import ClassificationDataModule

# todo: need to decide whether this should be a base class or just a protocol.
# - IF this is a protocol, then we can't use issubclass with it, so it can't be used in the
Expand Down
4 changes: 2 additions & 2 deletions project/datamodules/image_classification/imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@

from project.datamodules.vision import VisionDataModule
from project.utils.env_vars import DATA_DIR, NETWORK_DIR, NUM_WORKERS
from project.utils.types import C, H, W
from project.utils.types.protocols import Module
from project.utils.typing_utils import C, H, W
from project.utils.typing_utils.protocols import Module

logger = get_logger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion project/datamodules/image_classification/imagenet32.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from project.datamodules.vision import VisionDataModule
from project.utils.env_vars import DATA_DIR, SCRATCH
from project.utils.types import C, H, W
from project.utils.typing_utils import C, H, W

logger = getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion project/datamodules/image_classification/inaturalist.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
ImageClassificationDataModule,
)
from project.utils.env_vars import DATA_DIR, NUM_WORKERS, SLURM_TMPDIR
from project.utils.types import C, H, W
from project.utils.typing_utils import C, H, W

logger = get_logger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion project/datamodules/image_classification/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from project.datamodules.image_classification.image_classification import (
ImageClassificationDataModule,
)
from project.utils.types import C, H, W
from project.utils.typing_utils import C, H, W


def mnist_train_transforms():
Expand Down
Loading

0 comments on commit 2bb1a5d

Please sign in to comment.