Skip to content

Commit

Permalink
Re-enable datamodule regression tests
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <[email protected]>
  • Loading branch information
lebrice committed Sep 11, 2024
1 parent c7ccdb5 commit 9a16967
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
10 changes: 8 additions & 2 deletions project/datamodules/datamodules_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from lightning import LightningDataModule
from lightning.fabric.utilities.exceptions import MisconfigurationException
from lightning.pytorch.trainer.states import RunningStage
from omegaconf import DictConfig
from tensor_regression.fixture import (
TensorRegressionFixture,
get_test_source_and_temp_file_paths,
Expand All @@ -18,16 +19,21 @@
ImageClassificationDataModule,
)
from project.datamodules.vision import VisionDataModule
from project.experiment import instantiate_datamodule
from project.utils.env_vars import REPO_ROOTDIR
from project.utils.testutils import run_for_all_datamodules
from project.utils.typing_utils import is_sequence_of

logger = logging.getLogger(__name__)


# @use_overrides(["datamodule.num_workers=0"])
@pytest.fixture
def datamodule(experiment_dictconfig: DictConfig):
return instantiate_datamodule(experiment_dictconfig.datamodule)


# @pytest.mark.timeout(25, func_only=True)
@pytest.mark.slow
# @use_overrides(["datamodule.num_workers=0"])
@pytest.mark.parametrize(
"stage",
[
Expand Down
7 changes: 1 addition & 6 deletions project/utils/testutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,6 @@ def _parametrized_fixture_method(request: pytest.FixtureRequest):


def run_for_all_datamodules(
datamodule_names: list[str] | None = None,
datamodule_name_to_marks: dict[str, pytest.MarkDecorator | list[pytest.MarkDecorator]]
| None = None,
):
Expand All @@ -162,16 +161,12 @@ def run_for_all_datamodules(
Parameters
----------
datamodule_names: List of datamodule names to use for tests. \
By default, lists out the generic datamodules (the datamodules that aren't specific to a
single algorithm, for example the InfGendatamodules of WakeSleep.)
datamodule_to_marks: Dictionary from datamodule names to pytest marks (e.g. \
`pytest.mark.xfail`, `pytest.mark.skip`) to use for that particular datamodule.
"""
return run_for_all_configs_in_group(
group_name="datamodule",
config_name_to_marks=datamodule_name_to_marks,
group_name="datamodule", config_name_to_marks=datamodule_name_to_marks
)


Expand Down

0 comments on commit 9a16967

Please sign in to comment.