From 9a16967c23095f0eba149545ee652867f65267df Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Wed, 11 Sep 2024 21:04:02 +0000 Subject: [PATCH] Re-enable datamodule regression tests Signed-off-by: Fabrice Normandin --- project/datamodules/datamodules_test.py | 10 ++++++++-- project/utils/testutils.py | 7 +------ 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/project/datamodules/datamodules_test.py b/project/datamodules/datamodules_test.py index 31223752..6dd5648d 100644 --- a/project/datamodules/datamodules_test.py +++ b/project/datamodules/datamodules_test.py @@ -8,6 +8,7 @@ from lightning import LightningDataModule from lightning.fabric.utilities.exceptions import MisconfigurationException from lightning.pytorch.trainer.states import RunningStage +from omegaconf import DictConfig from tensor_regression.fixture import ( TensorRegressionFixture, get_test_source_and_temp_file_paths, @@ -18,6 +19,7 @@ ImageClassificationDataModule, ) from project.datamodules.vision import VisionDataModule +from project.experiment import instantiate_datamodule from project.utils.env_vars import REPO_ROOTDIR from project.utils.testutils import run_for_all_datamodules from project.utils.typing_utils import is_sequence_of @@ -25,9 +27,13 @@ logger = logging.getLogger(__name__) -# @use_overrides(["datamodule.num_workers=0"]) +@pytest.fixture +def datamodule(experiment_dictconfig: DictConfig): + return instantiate_datamodule(experiment_dictconfig.datamodule) + + # @pytest.mark.timeout(25, func_only=True) -@pytest.mark.slow +# @use_overrides(["datamodule.num_workers=0"]) @pytest.mark.parametrize( "stage", [ diff --git a/project/utils/testutils.py b/project/utils/testutils.py index d730798f..2a84bbf6 100644 --- a/project/utils/testutils.py +++ b/project/utils/testutils.py @@ -151,7 +151,6 @@ def _parametrized_fixture_method(request: pytest.FixtureRequest): def run_for_all_datamodules( - datamodule_names: list[str] | None = None, datamodule_name_to_marks: dict[str, pytest.MarkDecorator | list[pytest.MarkDecorator]] | None = None, ): @@ -162,16 +161,12 @@ def run_for_all_datamodules( Parameters ---------- - datamodule_names: List of datamodule names to use for tests. \ - By default, lists out the generic datamodules (the datamodules that aren't specific to a - single algorithm, for example the InfGendatamodules of WakeSleep.) datamodule_to_marks: Dictionary from datamodule names to pytest marks (e.g. \ `pytest.mark.xfail`, `pytest.mark.skip`) to use for that particular datamodule. """ return run_for_all_configs_in_group( - group_name="datamodule", - config_name_to_marks=datamodule_name_to_marks, + group_name="datamodule", config_name_to_marks=datamodule_name_to_marks )