From fcb3188f894e938cf2e068e077f6cf9d24803ff6 Mon Sep 17 00:00:00 2001
From: "K.Filippopolitis" <56073635+KFilippopolitis@users.noreply.github.com>
Date: Mon, 22 Jul 2024 13:51:24 +0300
Subject: [PATCH] Split datasets into training datasets and validation datasets. (#490)

---
 .../workflows/algorithm_validation_tests.yml   |   8 +
 .../flower/inputdata_preprocessing.py          |   3 +-
 .../flower/logistic_regression.json            |   3 +-
 exareme2/algorithms/specifications.py          |   1 +
 .../services/api/algorithm_request_dtos.py     |   1 +
 .../api/algorithm_request_validator.py         |  93 ++++++++---
 .../services/api/algorithm_spec_dtos.py        |  19 +++
 .../controller/services/flower/controller.py   |  16 +-
 .../worker_landscape_aggregator.py             |  37 ++++-
 .../flower/test_logistic_regression.py         |   4 +-
 .../flower/test_mnist_logistic_regression.py   |   2 -
 .../one_node_deployment_template.toml          |   2 +-
 ...st_flower_logisticregression_validation.py  |   2 +-
 .../flower/test_inputdata_preprocessing.py     |   1 +
 .../flower/test_processes_garbage_collect.py   |   2 +-
 .../api/test_validate_algorithm_request.py     | 146 +++++++++++++++---
 .../test_pos_and_kw_args_in_algorithm_flow.py  |   2 +-
 17 files changed, 283 insertions(+), 59 deletions(-)

diff --git a/.github/workflows/algorithm_validation_tests.yml b/.github/workflows/algorithm_validation_tests.yml
index b4f3e6b79..cebf1efbb 100644
--- a/.github/workflows/algorithm_validation_tests.yml
+++ b/.github/workflows/algorithm_validation_tests.yml
@@ -107,6 +107,9 @@ jobs:
       - name: Controller logs
         run: cat /tmp/exareme2/controller.out
 
+      - name: Globalworker logs
+        run: cat /tmp/exareme2/globalworker.out
+
       - name: Localworker logs
         run: cat /tmp/exareme2/localworker1.out
 
@@ -115,6 +118,11 @@ jobs:
         with:
           run: cat /tmp/exareme2/controller.out
 
+      - name: Globalworker logs (post run)
+        uses: webiny/action-post-run@3.0.0
+        with:
+          run: cat /tmp/exareme2/globalworker.out
+
       - name: Localworker logs (post run)
         uses: webiny/action-post-run@3.0.0
         with:

diff --git a/exareme2/algorithms/flower/inputdata_preprocessing.py b/exareme2/algorithms/flower/inputdata_preprocessing.py
index e3891c388..2253aa8cb 100644
--- a/exareme2/algorithms/flower/inputdata_preprocessing.py
+++ b/exareme2/algorithms/flower/inputdata_preprocessing.py
@@ -24,6 +24,7 @@ class Inputdata(BaseModel):
     data_model: str
     datasets: List[str]
+    validation_datasets: List[str]
     filters: Optional[dict]
     y: Optional[List[str]]
     x: Optional[List[str]]
@@ -32,7 +33,7 @@ def apply_inputdata(df: pd.DataFrame, inputdata: Inputdata) -> pd.DataFrame:
     if inputdata.filters:
         df = apply_filter(df, inputdata.filters)
-    df = df[df["dataset"].isin(inputdata.datasets)]
+    df = df[df["dataset"].isin(inputdata.datasets + inputdata.validation_datasets)]
     columns = inputdata.x + inputdata.y
     df = df[columns]
     df = df.dropna(subset=columns)

diff --git a/exareme2/algorithms/flower/logistic_regression.json b/exareme2/algorithms/flower/logistic_regression.json
index 959063d3d..3685720f3 100644
--- a/exareme2/algorithms/flower/logistic_regression.json
+++ b/exareme2/algorithms/flower/logistic_regression.json
@@ -32,6 +32,7 @@
       ],
       "notblank": true,
       "multiple": true
-    }
+    },
+    "validation": true
   }
 }
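The preprocessing change above can be seen in isolation with a minimal pandas sketch (standalone toy data, not part of the patch): the dataset filter now keeps rows belonging to either list, and an empty validation list degenerates to the old behavior.

    import pandas as pd

    df = pd.DataFrame(
        {
            "dataset": ["ppmi1", "ppmi2", "ppmi_test", "other"],
            "feature1": [0.1, 0.2, 0.3, 0.4],
            "target": [1, 0, 1, 0],
        }
    )
    datasets = ["ppmi1", "ppmi2"]        # training datasets
    validation_datasets = ["ppmi_test"]  # validation datasets

    # Union of the two lists: "other" is dropped, "ppmi_test" is kept.
    df = df[df["dataset"].isin(datasets + validation_datasets)]
    print(df["dataset"].tolist())  # ['ppmi1', 'ppmi2', 'ppmi_test']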
diff --git a/exareme2/algorithms/specifications.py b/exareme2/algorithms/specifications.py
index 2480accbb..431b0c254 100644
--- a/exareme2/algorithms/specifications.py
+++ b/exareme2/algorithms/specifications.py
@@ -102,6 +102,7 @@ class InputDataSpecification(ImmutableBaseModel):
 class InputDataSpecifications(ImmutableBaseModel):
     y: InputDataSpecification
     x: Optional[InputDataSpecification]
+    validation: Optional[bool]
 
 
 class ParameterEnumSpecification(ImmutableBaseModel):

diff --git a/exareme2/controller/services/api/algorithm_request_dtos.py b/exareme2/controller/services/api/algorithm_request_dtos.py
index 7609b2a09..deebc6cfd 100644
--- a/exareme2/controller/services/api/algorithm_request_dtos.py
+++ b/exareme2/controller/services/api/algorithm_request_dtos.py
@@ -24,6 +24,7 @@ class Config:
 class AlgorithmInputDataDTO(ImmutableBaseModel):
     data_model: str
     datasets: List[str]
+    validation_datasets: Optional[List[str]]
     filters: Optional[dict]
     y: Optional[List[str]]
     x: Optional[List[str]]

diff --git a/exareme2/controller/services/api/algorithm_request_validator.py b/exareme2/controller/services/api/algorithm_request_validator.py
index a0d1753d1..805baaa87 100644
--- a/exareme2/controller/services/api/algorithm_request_validator.py
+++ b/exareme2/controller/services/api/algorithm_request_validator.py
@@ -51,24 +51,21 @@ def validate_algorithm_request(
         algorithm_name, algorithm_request_dto.type, algorithms_specs
     )
 
-    available_datasets_per_data_model = (
-        worker_landscape_aggregator.get_all_available_datasets_per_data_model()
-    )
-
-    _validate_data_model(
-        requested_data_model=algorithm_request_dto.inputdata.data_model,
-        available_datasets_per_data_model=available_datasets_per_data_model,
+    (
+        training_datasets,
+        validation_datasets,
+    ) = worker_landscape_aggregator.get_train_and_validation_datasets(
+        algorithm_request_dto.inputdata.data_model
     )
-
     data_model_cdes = worker_landscape_aggregator.get_cdes(
         algorithm_request_dto.inputdata.data_model
     )
-
     _validate_algorithm_request_body(
         algorithm_request_dto=algorithm_request_dto,
         algorithm_specs=algorithm_specs,
         transformers_specs=transformers_specs,
-        available_datasets_per_data_model=available_datasets_per_data_model,
+        training_datasets=training_datasets,
+        validation_datasets=validation_datasets,
         data_model_cdes=data_model_cdes,
         smpc_enabled=smpc_enabled,
         smpc_optional=smpc_optional,
@@ -89,15 +86,21 @@ def _validate_algorithm_request_body(
     algorithm_request_dto: AlgorithmRequestDTO,
     algorithm_specs: AlgorithmSpecification,
     transformers_specs: Dict[str, TransformerSpecification],
-    available_datasets_per_data_model: Dict[str, List[str]],
+    training_datasets: List[str],
+    validation_datasets: List[str],
     data_model_cdes: Dict[str, CommonDataElement],
     smpc_enabled: bool,
     smpc_optional: bool,
 ):
+    _ensure_validation_criteria(
+        algorithm_request_dto.inputdata.validation_datasets,
+        algorithm_specs.inputdata.validation,
+    )
     _validate_inputdata(
         inputdata=algorithm_request_dto.inputdata,
         inputdata_specs=algorithm_specs.inputdata,
-        available_datasets_per_data_model=available_datasets_per_data_model,
+        training_datasets=training_datasets,
+        validation_datasets=validation_datasets,
         data_model_cdes=data_model_cdes,
     )
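The _ensure_validation_criteria check added above (defined in the next hunk) enforces agreement between the specification's validation flag and the request's validation datasets. A minimal standalone sketch of that rule; the BadUserInput class here is a stand-in for the project's exception:

    class BadUserInput(Exception):
        pass


    def ensure_validation_criteria(validation_datasets, validation):
        # Reject validation datasets when the algorithm spec has no validation step.
        if not validation and validation_datasets:
            raise BadUserInput("Validation is false, but validation datasets were provided.")
        # Reject a validation-capable algorithm invoked without validation datasets.
        if validation and not validation_datasets:
            raise BadUserInput("Validation is true, but no validation datasets were provided.")


    ensure_validation_criteria(["ppmi_test"], validation=True)  # accepted
    ensure_validation_criteria([], validation=False)            # accepted
    try:
        ensure_validation_criteria(["ppmi_test"], validation=False)
    except BadUserInput as err:
        print(err)  # Validation is false, but validation datasets were provided.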
@@ -122,38 +125,58 @@
-def _validate_data_model(requested_data_model: str, available_datasets_per_data_model):
-    if requested_data_model not in available_datasets_per_data_model.keys():
-        raise BadUserInput(f"Data model '{requested_data_model}' does not exist.")
+def _ensure_validation_criteria(validation_datasets: List[str], validation: bool):
+    """
+    Validates the input based on the provided validation flag and datasets.
+
+    Parameters:
+        validation_datasets (List[str]): List of validation datasets.
+        validation (bool): Flag indicating whether validation is required.
+
+    Raises:
+        BadUserInput: If the input conditions are not met.
+    """
+    if not validation and validation_datasets:
+        raise BadUserInput(
+            "Validation is false, but validation datasets were provided."
+        )
+    elif validation and not validation_datasets:
+        raise BadUserInput(
+            "Validation is true, but no validation datasets were provided."
+        )
 
 
 def _validate_inputdata(
     inputdata: AlgorithmInputDataDTO,
     inputdata_specs: InputDataSpecifications,
-    available_datasets_per_data_model: Dict[str, List[str]],
+    training_datasets: List[str],
+    validation_datasets: List[str],
     data_model_cdes: Dict[str, CommonDataElement],
 ):
-    _validate_inputdata_dataset(
+    _validate_inputdata_training_datasets(
         requested_data_model=inputdata.data_model,
         requested_datasets=inputdata.datasets,
-        available_datasets_per_data_model=available_datasets_per_data_model,
+        training_datasets=training_datasets,
+    )
+    _validate_inputdata_validation_datasets(
+        requested_data_model=inputdata.data_model,
+        requested_validation_datasets=inputdata.validation_datasets,
+        validation_datasets=validation_datasets,
     )
     _validate_inputdata_filter(inputdata.data_model, inputdata.filters, data_model_cdes)
     _validate_algorithm_inputdatas(inputdata, inputdata_specs, data_model_cdes)
 
 
-def _validate_inputdata_dataset(
+def _validate_inputdata_training_datasets(
     requested_data_model: str,
     requested_datasets: List[str],
-    available_datasets_per_data_model: Dict[str, List[str]],
+    training_datasets: List[str],
 ):
     """
-    Validates that the dataset values exist and that the datasets belong in the data_model.
+    Validates that the requested datasets exist and are available as training datasets.
     """
     non_existing_datasets = [
-        dataset
-        for dataset in requested_datasets
-        if dataset not in available_datasets_per_data_model[requested_data_model]
+        dataset for dataset in requested_datasets if dataset not in training_datasets
     ]
     if non_existing_datasets:
         raise BadUserInput(
@@ -161,6 +184,28 @@
         )
 
 
+def _validate_inputdata_validation_datasets(
+    requested_data_model: str,
+    requested_validation_datasets: List[str],
+    validation_datasets: List[str],
+):
+    """
+    Validates that the requested validation datasets exist and are available as validation datasets.
+    """
+    if not requested_validation_datasets:
+        return
+
+    non_existing_datasets = [
+        dataset
+        for dataset in requested_validation_datasets
+        if dataset not in validation_datasets
+    ]
+    if non_existing_datasets:
+        raise BadUserInput(
+            f"Validation Datasets:'{non_existing_datasets}' could not be found for data_model:{requested_data_model}"
+        )
+
+
 def _validate_inputdata_filter(data_model, filter, data_model_cdes):
     """
     Validates that the filter provided have the correct format
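The membership check both validators perform reduces to a list comprehension against the corresponding availability list. A toy standalone sketch (dataset and data-model names are illustrative; the real code raises BadUserInput instead of printing):

    available_validation_datasets = ["ppmi_test"]
    requested_validation_datasets = ["ppmi_test", "ppmi_typo"]

    non_existing = [
        dataset
        for dataset in requested_validation_datasets
        if dataset not in available_validation_datasets
    ]
    if non_existing:
        print(f"Validation Datasets:'{non_existing}' could not be found for data_model:dementia:0.1")
    # -> Validation Datasets:'['ppmi_typo']' could not be found for data_model:dementia:0.1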
diff --git a/exareme2/controller/services/api/algorithm_spec_dtos.py b/exareme2/controller/services/api/algorithm_spec_dtos.py
index 998001376..c98b4a467 100644
--- a/exareme2/controller/services/api/algorithm_spec_dtos.py
+++ b/exareme2/controller/services/api/algorithm_spec_dtos.py
@@ -48,6 +48,7 @@ class InputDataSpecificationsDTO(ImmutableBaseModel):
     filter: InputDataSpecificationDTO
     y: InputDataSpecificationDTO
     x: Optional[InputDataSpecificationDTO]
+    validation_datasets: Optional[InputDataSpecificationDTO]
 
 
 class ParameterEnumSpecificationDTO(ImmutableBaseModel):
@@ -121,6 +122,18 @@ def _get_data_model_input_data_specification_dto():
     )
 
 
+def _get_validation_datasets_input_data_specification_dto():
+    return InputDataSpecificationDTO(
+        label="Set of data to validate.",
+        desc="The set of data to validate the algorithm model on.",
+        types=[InputDataType.TEXT],
+        notblank=True,
+        multiple=True,
+        stattypes=None,
+        enumslen=None,
+    )
+
+
 def _get_datasets_input_data_specification_dto():
     return InputDataSpecificationDTO(
         label="Set of data to use.",
@@ -150,9 +163,15 @@ def _convert_inputdata_specifications_to_dto(spec: InputDataSpecifications):
     # These parameters are not added by the algorithm developer.
     y = _convert_inputdata_specification_to_dto(spec.y)
     x = _convert_inputdata_specification_to_dto(spec.x) if spec.x else None
+    validation_datasets = (
+        _get_validation_datasets_input_data_specification_dto()
+        if spec.validation
+        else None
+    )
     return InputDataSpecificationsDTO(
         y=y,
         x=x,
+        validation_datasets=validation_datasets,
         data_model=_get_data_model_input_data_specification_dto(),
         datasets=_get_datasets_input_data_specification_dto(),
         filter=_get_filters_input_data_specification_dto(),
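A sketch of the conversion rule with simplified stand-ins for the real pydantic models: the validation_datasets field of the resulting DTO is populated only when the algorithm specification sets validation to true.

    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class InputDataSpecificationsSketch:
        validation: Optional[bool] = None


    def convert_validation_field(spec: InputDataSpecificationsSketch) -> Optional[str]:
        # Placeholder string where the real code builds an InputDataSpecificationDTO.
        return "validation_datasets specification" if spec.validation else None


    print(convert_validation_field(InputDataSpecificationsSketch(validation=True)))  # populated
    print(convert_validation_field(InputDataSpecificationsSketch()))                 # None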
diff --git a/exareme2/controller/services/flower/controller.py b/exareme2/controller/services/flower/controller.py
index 5cd51c282..d71116df0 100644
--- a/exareme2/controller/services/flower/controller.py
+++ b/exareme2/controller/services/flower/controller.py
@@ -1,5 +1,4 @@
 import asyncio
-import warnings
 from typing import Dict
 from typing import List
 
@@ -51,16 +50,21 @@ def _create_worker_tasks_handler(self, request_id, worker_info: WorkerInfo):
         )
 
     async def exec_algorithm(self, algorithm_name, algorithm_request_dto):
         async with self.lock:
             request_id = algorithm_request_dto.request_id
             context_id = UIDGenerator().get_a_uid()
             logger = ctrl_logger.get_request_logger(request_id)
+            datasets = algorithm_request_dto.inputdata.datasets + (
+                algorithm_request_dto.inputdata.validation_datasets
+                if algorithm_request_dto.inputdata.validation_datasets
+                else []
+            )
             csv_paths_per_worker_id: Dict[
                 str, List[str]
             ] = self.worker_landscape_aggregator.get_csv_paths_per_worker_id(
-                algorithm_request_dto.inputdata.data_model,
-                algorithm_request_dto.inputdata.datasets,
+                algorithm_request_dto.inputdata.data_model, datasets
             )
+
             workers_info = [
                 self.worker_landscape_aggregator.get_worker_info(worker_id)
                 for worker_id in csv_paths_per_worker_id
@@ -93,7 +97,9 @@ async def exec_algorithm(self, algorithm_name, algorithm_request_dto):
                 algorithm_name,
                 len(task_handlers),
                 str(server_address),
-                csv_paths_per_worker_id[server_id],
+                csv_paths_per_worker_id[server_id]
+                if algorithm_request_dto.inputdata.validation_datasets
+                else [],
             )
             clients_pids = {
                 handler.start_flower_client(
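Since validation_datasets is Optional on the request DTO, the controller defaults it to an empty list before concatenating. A minimal standalone sketch of that union:

    datasets = ["ppmi1", "ppmi2"]
    validation_datasets = None  # None when the request carries no validation datasets

    # Default the optional list to [] so concatenation never fails on None.
    all_requested = datasets + (validation_datasets if validation_datasets else [])
    print(all_requested)  # ['ppmi1', 'ppmi2']

The conditional expression is equivalent to the shorter form datasets + (validation_datasets or []).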
diff --git a/exareme2/controller/services/worker_landscape_aggregator/worker_landscape_aggregator.py b/exareme2/controller/services/worker_landscape_aggregator/worker_landscape_aggregator.py
index 15b75a2cb..25237b133 100644
--- a/exareme2/controller/services/worker_landscape_aggregator/worker_landscape_aggregator.py
+++ b/exareme2/controller/services/worker_landscape_aggregator/worker_landscape_aggregator.py
@@ -26,6 +26,7 @@
 )
 from exareme2.controller.workers_addresses import WorkersAddressesFactory
 from exareme2.utils import AttrDict
+from exareme2.worker_communication import BadUserInput
 from exareme2.worker_communication import CommonDataElement
 from exareme2.worker_communication import CommonDataElements
 from exareme2.worker_communication import DataModelAttributes
@@ -145,6 +146,7 @@ def get_csv_paths_per_worker_id(
         ].items()
         if dataset in datasets
     ]
+
     for dataset_info in dataset_infos:
         if not dataset_info.csv_path:
             raise DatasetMissingCsvPathError()
@@ -514,6 +516,37 @@ def get_cdes_per_data_model(self) -> DataModelsCDES:
     def get_datasets_locations(self) -> DatasetsLocations:
         return self._registries.data_model_registry.datasets_locations
 
+    def get_train_and_validation_datasets(
+        self, data_model: str
+    ) -> Tuple[List[str], List[str]]:
+        """
+        Retrieves all available training and validation datasets for a specific data model.
+
+        Parameters:
+            data_model (str): The data model for which to retrieve datasets.
+
+        Returns:
+            Tuple[List[str], List[str]]: A tuple containing two lists:
+                - The first list contains training datasets.
+                - The second list contains validation datasets.
+        """
+        training_datasets = []
+        validation_datasets = []
+
+        if data_model not in self.get_datasets_locations().datasets_locations.keys():
+            raise BadUserInput(f"Data model '{data_model}' does not exist.")
+        datasets_locations = self.get_datasets_locations().datasets_locations[
+            data_model
+        ]
+
+        for dataset, dataset_location in datasets_locations.items():
+            if dataset_location.worker_id == self.get_global_worker().id:
+                validation_datasets.append(dataset)
+            else:
+                training_datasets.append(dataset)
+
+        return training_datasets, validation_datasets
+
     def get_all_available_datasets_per_data_model(self) -> Dict[str, List[str]]:
         return (
             self._registries.data_model_registry.get_all_available_datasets_per_data_model()
@@ -815,8 +848,8 @@ def _remove_incompatible_data_models_from_data_models_metadata_per_worker(
         data_models_metadata_per_worker: DataModelsMetadataPerWorker
     Returns
     ----------
-    List[str]
-        The incompatible data models
+    DataModelsMetadataPerWorker
+        The data_models_metadata_per_worker with the incompatible data models removed
     """
     validation_dictionary = {}
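The split rule itself is a single pass over the dataset locations: anything hosted on the global worker is treated as a validation dataset, anything on a local worker as a training dataset. A standalone sketch with illustrative worker and dataset names:

    datasets_locations = {
        "ppmi1": "localworker1",
        "ppmi2": "localworker2",
        "ppmi_test": "globalworker",
    }
    global_worker_id = "globalworker"

    training_datasets, validation_datasets = [], []
    for dataset, worker_id in datasets_locations.items():
        # Datasets on the global worker are reserved for validation.
        if worker_id == global_worker_id:
            validation_datasets.append(dataset)
        else:
            training_datasets.append(dataset)

    print(training_datasets)    # ['ppmi1', 'ppmi2']
    print(validation_datasets)  # ['ppmi_test']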
diff --git a/tests/algorithm_validation_tests/flower/test_logistic_regression.py b/tests/algorithm_validation_tests/flower/test_logistic_regression.py
index de166f7d5..3b61ed5e6 100644
--- a/tests/algorithm_validation_tests/flower/test_logistic_regression.py
+++ b/tests/algorithm_validation_tests/flower/test_logistic_regression.py
@@ -15,8 +15,8 @@ def test_logistic_regression(get_algorithm_result):
                 "ppmi7",
                 "ppmi8",
                 "ppmi9",
-                "ppmi_test",
             ],
+            "validation_datasets": ["ppmi_test"],
             "filters": None,
         },
         "parameters": None,
@@ -44,8 +44,8 @@ def test_logistic_regression_with_filters(get_algorithm_result):
                 "ppmi7",
                 "ppmi8",
                 "ppmi9",
-                "ppmi_test",
             ],
+            "validation_datasets": ["ppmi_test"],
             "filters": {
                 "condition": "AND",
                 "rules": [

diff --git a/tests/algorithm_validation_tests/flower/test_mnist_logistic_regression.py b/tests/algorithm_validation_tests/flower/test_mnist_logistic_regression.py
index 3bbaad8a9..a5bc5ba68 100644
--- a/tests/algorithm_validation_tests/flower/test_mnist_logistic_regression.py
+++ b/tests/algorithm_validation_tests/flower/test_mnist_logistic_regression.py
@@ -15,7 +15,6 @@ def test_mnist_logistic_regression(get_algorithm_result):
             "ppmi7",
             "ppmi8",
             "ppmi9",
-            "ppmi_test",
         ],
         "filters": None,
     },
@@ -24,5 +23,4 @@ def test_mnist_logistic_regression(get_algorithm_result):
     input["type"] = "flower"
     algorithm_result = get_algorithm_result("mnist_logistic_regression", input)
-    assert "accuracy" in algorithm_result
     assert {"accuracy": 0.8486} == algorithm_result

diff --git a/tests/algorithm_validation_tests/one_node_deployment_template.toml b/tests/algorithm_validation_tests/one_node_deployment_template.toml
index b9bc263ef..95cee1484 100644
--- a/tests/algorithm_validation_tests/one_node_deployment_template.toml
+++ b/tests/algorithm_validation_tests/one_node_deployment_template.toml
@@ -9,7 +9,7 @@ monetdb_memory_limit = 4096 # MB
 
 algorithm_folders = "./exareme2/algorithms/exareme2,./exareme2/algorithms/flower,./tests/algorithms"
 
-worker_landscape_aggregator_update_interval = 30
+worker_landscape_aggregator_update_interval = 300
 flower_execution_timeout = 30
 celery_tasks_timeout = 60
 celery_cleanup_task_timeout=2

diff --git a/tests/prod_env_tests/test_flower_logisticregression_validation.py b/tests/prod_env_tests/test_flower_logisticregression_validation.py
index 10c7cfe7c..0b4ed91f0 100644
--- a/tests/prod_env_tests/test_flower_logisticregression_validation.py
+++ b/tests/prod_env_tests/test_flower_logisticregression_validation.py
@@ -21,8 +21,8 @@ def test_logisticregression_algorithm():
                 "ppmi7",
                 "ppmi8",
                 "ppmi9",
-                "ppmi_test",
             ],
+            "validation_datasets": ["ppmi_test"],
             "filters": None,
         },
         "parameters": {},

diff --git a/tests/standalone_tests/algorithms/flower/test_inputdata_preprocessing.py b/tests/standalone_tests/algorithms/flower/test_inputdata_preprocessing.py
index 2c22f4330..46e3b695e 100644
--- a/tests/standalone_tests/algorithms/flower/test_inputdata_preprocessing.py
+++ b/tests/standalone_tests/algorithms/flower/test_inputdata_preprocessing.py
@@ -57,6 +57,7 @@ def test_get_input_success(self, mock_post, mock_get):
             {
                 "data_model": "model",
                 "datasets": ["dataset1"],
+                "validation_datasets": ["validation_dataset1"],
                 "filters": None,
                 "y": ["target"],
                 "x": ["feature1"],

diff --git a/tests/standalone_tests/algorithms/flower/test_processes_garbage_collect.py b/tests/standalone_tests/algorithms/flower/test_processes_garbage_collect.py
index f9ab9c25a..c2ad46d3b 100644
--- a/tests/standalone_tests/algorithms/flower/test_processes_garbage_collect.py
+++ b/tests/standalone_tests/algorithms/flower/test_processes_garbage_collect.py
@@ -49,8 +49,8 @@ def test_processes_garbage_collect(
             "ppmi1",
             "ppmi2",
             "ppmi3",
-            "ppmi_test",
         ],
+        "validation_datasets": ["ppmi_test"],
         "filters": None,
     },
     "type": "flower",

diff --git a/tests/standalone_tests/controller/services/api/test_validate_algorithm_request.py b/tests/standalone_tests/controller/services/api/test_validate_algorithm_request.py
index 53dd25416..da4083736 100644
--- a/tests/standalone_tests/controller/services/api/test_validate_algorithm_request.py
+++ b/tests/standalone_tests/controller/services/api/test_validate_algorithm_request.py
@@ -1,3 +1,5 @@
+from unittest.mock import patch
+
 import pytest
 
 from exareme2.algorithms.specifications import AlgorithmSpecification
@@ -126,6 +128,10 @@ def worker_landscape_aggregator():
                     worker_id="sample_worker",
                     csv_path="/opt/data/sample_dataset2.csv",
                 ),
+                "sample_dataset3": DatasetLocation(
+                    worker_id="globalworker",
+                    csv_path="/opt/data/sample_dataset3.csv",
+                ),
             },
             "sample_data_model:0.1": {
                 "sample_dataset": DatasetLocation(
@@ -163,6 +169,27 @@ def algorithms_specs():
             ),
         ),
     ),
+    (
+        "algorithm_with_y_int_and_validation",
+        AlgorithmType.EXAREME2,
+    ): AlgorithmSpecification(
+        name="algorithm_with_y_int_and_validation",
+        desc="algorithm_with_y_int_and_validation",
+        label="algorithm_with_y_int_and_validation",
+        enabled=True,
+        type=AlgorithmType.EXAREME2,
+        inputdata=InputDataSpecifications(
+            y=InputDataSpecification(
+                label="features",
+                desc="Features",
+                types=[InputDataType.REAL],
+                stattypes=[InputDataStatType.NUMERICAL],
+                notblank=True,
+                multiple=False,
+            ),
+            validation=True,
+        ),
+    ),
     (
         "algorithm_with_x_int_and_y_text",
         AlgorithmType.EXAREME2,
@@ -542,6 +569,19 @@ def get_parametrization_list_success_cases():
             ),
             id="multiple datasets",
         ),
+        pytest.param(
+            "algorithm_with_y_int_and_validation",
+            AlgorithmRequestDTO(
+                type=AlgorithmType.EXAREME2,
+                inputdata=AlgorithmInputDataDTO(
+                    data_model="data_model_with_all_cde_types:0.1",
+                    datasets=["sample_dataset1", "sample_dataset2"],
+                    validation_datasets=["sample_dataset3"],
+                    y=["int_cde"],
+                ),
+            ),
+            id="multiple datasets and validation dataset",
+        ),
         pytest.param(
             "algorithm_with_x_int_and_y_text",
             AlgorithmRequestDTO(
@@ -772,6 +812,16 @@ def get_parametrization_list_success_cases():
     return parametrization_list
 
 
+class WorkerInfo:
+    def __init__(self, id):
+        self.id = id
+
+
+mocked_worker_info = WorkerInfo(
+    id="globalworker",
+)
+
+
 @pytest.mark.parametrize(
     "algorithm_name, request_dto", get_parametrization_list_success_cases()
 )
 def test_validate_algorithm_success(
     algorithm_name,
     request_dto,
     worker_landscape_aggregator,
     algorithms_specs,
     transformers_specs,
 ):
-    validate_algorithm_request(
-        algorithm_name=algorithm_name,
-        algorithm_request_dto=request_dto,
-        algorithms_specs=algorithms_specs,
-        transformers_specs=transformers_specs,
-        worker_landscape_aggregator=worker_landscape_aggregator,
-        smpc_enabled=False,
-        smpc_optional=False,
-    )
+    with patch.object(
+        worker_landscape_aggregator,
+        "get_global_worker",
+        return_value=mocked_worker_info,
+    ):
+        validate_algorithm_request(
+            algorithm_name=algorithm_name,
+            algorithm_request_dto=request_dto,
+            algorithms_specs=algorithms_specs,
+            transformers_specs=transformers_specs,
+            worker_landscape_aggregator=worker_landscape_aggregator,
+            smpc_enabled=False,
+            smpc_optional=False,
+        )
 
 
 def get_parametrization_list_exception_cases():
@@ -823,6 +878,23 @@ def get_parametrization_list_exception_cases():
             ),
             id="Dataset does not exist.",
         ),
+        pytest.param(
+            "algorithm_with_y_int_and_validation",
+            AlgorithmRequestDTO(
+                type=AlgorithmType.EXAREME2,
+                inputdata=AlgorithmInputDataDTO(
+                    data_model="data_model_with_all_cde_types:0.1",
+                    datasets=["sample_dataset1"],
+                    validation_datasets=["non_existing_dataset"],
+                    y=["int_cde"],
+                ),
+            ),
+            (
+                BadUserInput,
+                "Validation Datasets:.* could not be found for data_model:.*",
+            ),
+            id="Validation dataset does not exist.",
+        ),
         pytest.param(
             "algorithm_with_y_int",
             AlgorithmRequestDTO(
@@ -1327,6 +1399,39 @@ def get_parametrization_list_exception_cases():
             ),
             id="flag does not have boolean value",
         ),
+        pytest.param(
+            "algorithm_with_y_int",
+            AlgorithmRequestDTO(
+                type=AlgorithmType.EXAREME2,
+                inputdata=AlgorithmInputDataDTO(
+                    data_model="data_model_with_all_cde_types:0.1",
+                    datasets=["sample_dataset1", "sample_dataset2"],
+                    validation_datasets=["sample_dataset3"],
+                    y=["int_cde"],
+                ),
+            ),
+            (
+                BadUserInput,
+                "Validation is false, but validation datasets were provided.",
+            ),
+            id="Validation datasets on algorithm without validation",
+        ),
+        pytest.param(
+            "algorithm_with_y_int_and_validation",
+            AlgorithmRequestDTO(
+                type=AlgorithmType.EXAREME2,
+                inputdata=AlgorithmInputDataDTO(
+                    data_model="data_model_with_all_cde_types:0.1",
+                    datasets=["sample_dataset1", "sample_dataset2"],
+                    y=["int_cde"],
+                ),
+            ),
+            (
+                BadUserInput,
+                "Validation is true, but no validation datasets were provided.",
+            ),
+            id="Missing validation datasets on algorithm with validation",
+        ),
     ]
 
     return parametrization_list
@@ -1344,12 +1449,17 @@ def test_validate_algorithm_exceptions(
 ):
     exception_type, exception_message = exception
     with pytest.raises(exception_type, match=exception_message):
-        validate_algorithm_request(
-            algorithm_name=algorithm_name,
-            algorithm_request_dto=request_dto,
-            algorithms_specs=algorithms_specs,
-            transformers_specs=transformers_specs,
-            worker_landscape_aggregator=worker_landscape_aggregator,
-            smpc_enabled=False,
-            smpc_optional=False,
-        )
+        with patch.object(
+            worker_landscape_aggregator,
+            "get_global_worker",
+            return_value=mocked_worker_info,
+        ):
+            validate_algorithm_request(
+                algorithm_name=algorithm_name,
+                algorithm_request_dto=request_dto,
+                algorithms_specs=algorithms_specs,
+                transformers_specs=transformers_specs,
+                worker_landscape_aggregator=worker_landscape_aggregator,
+                smpc_enabled=False,
+                smpc_optional=False,
+            )
diff --git a/tests/standalone_tests/controller/services/flower/test_pos_and_kw_args_in_algorithm_flow.py b/tests/standalone_tests/controller/services/flower/test_pos_and_kw_args_in_algorithm_flow.py
index 0576fee42..34b6bef9f 100644
--- a/tests/standalone_tests/controller/services/flower/test_pos_and_kw_args_in_algorithm_flow.py
+++ b/tests/standalone_tests/controller/services/flower/test_pos_and_kw_args_in_algorithm_flow.py
@@ -25,8 +25,8 @@ def test_pos_and_kw_args_in_algorithm_flow(
             "ppmi1",
             "ppmi2",
             "ppmi3",
-            "ppmi_test",
         ],
+        "validation_datasets": ["ppmi_test"],
         "filters": None,
     },
     "type": "flower",