Split datasets into training datasets and validation datasets. (#490)
KFilippopolitis authored Jul 22, 2024
1 parent 644ef1b commit fcb3188
Showing 17 changed files with 283 additions and 59 deletions.
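
In short, an algorithm request now carries two dataset lists: training data stays under "datasets", while data hosted on the global worker is passed through the new "validation_datasets" field. A minimal sketch of the new inputdata payload follows (illustrative only; the dataset names come from the updated tests at the end of this diff, while the data model name and the y/x columns are assumptions):

# Sketch of the new request inputdata (illustrative, not taken verbatim from the commit).
inputdata = {
    "data_model": "dementia:0.1",             # assumed data model name
    "datasets": ["ppmi7", "ppmi8", "ppmi9"],  # training datasets, hosted on local workers
    "validation_datasets": ["ppmi_test"],     # validation datasets, hosted on the global worker
    "filters": None,
    "y": ["example_target"],                  # hypothetical target column
    "x": ["example_covariate"],               # hypothetical covariate column
}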
8 changes: 8 additions & 0 deletions .github/workflows/algorithm_validation_tests.yml
@@ -107,6 +107,9 @@ jobs:
- name: Controller logs
run: cat /tmp/exareme2/controller.out

- name: Globalworker logs
run: cat /tmp/exareme2/globalworker.out

- name: Localworker logs
run: cat /tmp/exareme2/localworker1.out

@@ -115,6 +118,11 @@ jobs:
with:
run: cat /tmp/exareme2/controller.out

- name: Globalworker logs (post run)
uses: webiny/[email protected]
with:
run: cat /tmp/exareme2/globalworker.out

- name: Localworker logs (post run)
uses: webiny/[email protected]
with:
3 changes: 2 additions & 1 deletion exareme2/algorithms/flower/inputdata_preprocessing.py
@@ -24,6 +24,7 @@
class Inputdata(BaseModel):
data_model: str
datasets: List[str]
validation_datasets: List[str]
filters: Optional[dict]
y: Optional[List[str]]
x: Optional[List[str]]
@@ -32,7 +33,7 @@ class Inputdata(BaseModel):
def apply_inputdata(df: pd.DataFrame, inputdata: Inputdata) -> pd.DataFrame:
if inputdata.filters:
df = apply_filter(df, inputdata.filters)
df = df[df["dataset"].isin(inputdata.datasets)]
df = df[df["dataset"].isin(inputdata.datasets + inputdata.validation_datasets)]
columns = inputdata.x + inputdata.y
df = df[columns]
df = df.dropna(subset=columns)
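
A hedged usage sketch of the updated apply_inputdata (not part of the commit): rows from both the training and the validation datasets are kept, every other row is dropped, and only the x + y columns survive. The Inputdata values below are illustrative.

import pandas as pd

# Illustrative frame; the column names are hypothetical.
df = pd.DataFrame(
    {
        "dataset": ["ppmi7", "ppmi_test", "other_study"],
        "age": [61.0, 58.0, 70.0],
        "diagnosis": ["PD", "HC", "PD"],
    }
)
inputdata = Inputdata(
    data_model="dementia:0.1",          # assumed data model name
    datasets=["ppmi7"],                 # training datasets
    validation_datasets=["ppmi_test"],  # validation rows are now kept as well
    filters=None,
    y=["diagnosis"],
    x=["age"],
)
# "other_study" rows are filtered out; only the "age" and "diagnosis" columns remain.
result = apply_inputdata(df, inputdata)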
3 changes: 2 additions & 1 deletion exareme2/algorithms/flower/logistic_regression.json
@@ -32,6 +32,7 @@
],
"notblank": true,
"multiple": true
}
},
"validation": true
}
}
1 change: 1 addition & 0 deletions exareme2/algorithms/specifications.py
@@ -102,6 +102,7 @@ class InputDataSpecification(ImmutableBaseModel):
class InputDataSpecifications(ImmutableBaseModel):
y: InputDataSpecification
x: Optional[InputDataSpecification]
validation: Optional[bool]


class ParameterEnumSpecification(ImmutableBaseModel):
1 change: 1 addition & 0 deletions exareme2/controller/services/api/algorithm_request_dtos.py
@@ -24,6 +24,7 @@ class Config:
class AlgorithmInputDataDTO(ImmutableBaseModel):
data_model: str
datasets: List[str]
validation_datasets: Optional[List[str]]
filters: Optional[dict]
y: Optional[List[str]]
x: Optional[List[str]]
93 changes: 69 additions & 24 deletions exareme2/controller/services/api/algorithm_request_validator.py
@@ -51,24 +51,21 @@ def validate_algorithm_request(
algorithm_name, algorithm_request_dto.type, algorithms_specs
)

available_datasets_per_data_model = (
worker_landscape_aggregator.get_all_available_datasets_per_data_model()
)

_validate_data_model(
requested_data_model=algorithm_request_dto.inputdata.data_model,
available_datasets_per_data_model=available_datasets_per_data_model,
(
training_datasets,
validation_datasets,
) = worker_landscape_aggregator.get_train_and_validation_datasets(
algorithm_request_dto.inputdata.data_model
)

data_model_cdes = worker_landscape_aggregator.get_cdes(
algorithm_request_dto.inputdata.data_model
)

_validate_algorithm_request_body(
algorithm_request_dto=algorithm_request_dto,
algorithm_specs=algorithm_specs,
transformers_specs=transformers_specs,
available_datasets_per_data_model=available_datasets_per_data_model,
training_datasets=training_datasets,
validation_datasets=validation_datasets,
data_model_cdes=data_model_cdes,
smpc_enabled=smpc_enabled,
smpc_optional=smpc_optional,
@@ -89,15 +86,21 @@ def _validate_algorithm_request_body(
algorithm_request_dto: AlgorithmRequestDTO,
algorithm_specs: AlgorithmSpecification,
transformers_specs: Dict[str, TransformerSpecification],
available_datasets_per_data_model: Dict[str, List[str]],
training_datasets: List[str],
validation_datasets: List[str],
data_model_cdes: Dict[str, CommonDataElement],
smpc_enabled: bool,
smpc_optional: bool,
):
_ensure_validation_criteria(
algorithm_request_dto.inputdata.validation_datasets,
algorithm_specs.inputdata.validation,
)
_validate_inputdata(
inputdata=algorithm_request_dto.inputdata,
inputdata_specs=algorithm_specs.inputdata,
available_datasets_per_data_model=available_datasets_per_data_model,
training_datasets=training_datasets,
validation_datasets=validation_datasets,
data_model_cdes=data_model_cdes,
)

@@ -122,45 +125,87 @@ def _validate_algorithm_request_body(
)


def _validate_data_model(requested_data_model: str, available_datasets_per_data_model):
if requested_data_model not in available_datasets_per_data_model.keys():
raise BadUserInput(f"Data model '{requested_data_model}' does not exist.")
def _ensure_validation_criteria(validation_datasets: List[str], validation: bool):
"""
Validates the input based on the provided validation flag and datasets.
Parameters:
validation_datasets (List[str]): List of validation datasets.
validation (bool): Flag indicating if validation is required.
Raises:
BadUserInput: If the input conditions are not met.
"""
if not validation and validation_datasets:
raise BadUserInput(
"Validation is false, but validation datasets were provided."
)
elif validation and not validation_datasets:
raise BadUserInput(
"Validation is true, but no validation datasets were provided."
)
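
Purely illustrative (not part of the diff), the two error paths described in the docstring surface like this:

# The first two calls raise BadUserInput; the third passes.
_ensure_validation_criteria(validation_datasets=["ppmi_test"], validation=False)
# -> "Validation is false, but validation datasets were provided."
_ensure_validation_criteria(validation_datasets=[], validation=True)
# -> "Validation is true, but no validation datasets were provided."
_ensure_validation_criteria(validation_datasets=["ppmi_test"], validation=True)  # ok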


def _validate_inputdata(
inputdata: AlgorithmInputDataDTO,
inputdata_specs: InputDataSpecifications,
available_datasets_per_data_model: Dict[str, List[str]],
training_datasets: List[str],
validation_datasets: List[str],
data_model_cdes: Dict[str, CommonDataElement],
):
_validate_inputdata_dataset(
_validate_inputdata_training_datasets(
requested_data_model=inputdata.data_model,
requested_datasets=inputdata.datasets,
available_datasets_per_data_model=available_datasets_per_data_model,
training_datasets=training_datasets,
)
_validate_inputdata_validation_datasets(
requested_data_model=inputdata.data_model,
requested_validation_datasets=inputdata.validation_datasets,
validation_datasets=validation_datasets,
)
_validate_inputdata_filter(inputdata.data_model, inputdata.filters, data_model_cdes)
_validate_algorithm_inputdatas(inputdata, inputdata_specs, data_model_cdes)


def _validate_inputdata_dataset(
def _validate_inputdata_training_datasets(
requested_data_model: str,
requested_datasets: List[str],
available_datasets_per_data_model: Dict[str, List[str]],
training_datasets: List[str],
):
"""
Validates that the dataset values exist and that the datasets belong in the data_model.
Validates that the requested datasets exist among the available training datasets.
"""
non_existing_datasets = [
dataset
for dataset in requested_datasets
if dataset not in available_datasets_per_data_model[requested_data_model]
dataset for dataset in requested_datasets if dataset not in training_datasets
]
if non_existing_datasets:
raise BadUserInput(
f"Datasets:'{non_existing_datasets}' could not be found for data_model:{requested_data_model}"
)


def _validate_inputdata_validation_datasets(
requested_data_model: str,
requested_validation_datasets: List[str],
validation_datasets: List[str],
):
"""
Validates that the requested validation datasets exist among the available validation datasets.
"""
if not requested_validation_datasets:
return

non_existing_datasets = [
dataset
for dataset in requested_validation_datasets
if dataset not in validation_datasets
]
if non_existing_datasets:
raise BadUserInput(
f"Validation Datasets:'{non_existing_datasets}' could not be found for data_model:{requested_data_model}"
)


def _validate_inputdata_filter(data_model, filter, data_model_cdes):
"""
Validates that the filter provided have the correct format
19 changes: 19 additions & 0 deletions exareme2/controller/services/api/algorithm_spec_dtos.py
@@ -48,6 +48,7 @@ class InputDataSpecificationsDTO(ImmutableBaseModel):
filter: InputDataSpecificationDTO
y: InputDataSpecificationDTO
x: Optional[InputDataSpecificationDTO]
validation_datasets: Optional[InputDataSpecificationDTO]


class ParameterEnumSpecificationDTO(ImmutableBaseModel):
@@ -121,6 +122,18 @@ def _get_data_model_input_data_specification_dto():
)


def _get_validation_datasets_input_data_specification_dto():
return InputDataSpecificationDTO(
label="Set of data to validate.",
desc="The set of data to validate the algorithm model on.",
types=[InputDataType.TEXT],
notblank=True,
multiple=True,
stattypes=None,
enumslen=None,
)


def _get_datasets_input_data_specification_dto():
return InputDataSpecificationDTO(
label="Set of data to use.",
@@ -150,9 +163,15 @@ def _convert_inputdata_specifications_to_dto(spec: InputDataSpecifications):
# These parameters are not added by the algorithm developer.
y = _convert_inputdata_specification_to_dto(spec.y)
x = _convert_inputdata_specification_to_dto(spec.x) if spec.x else None
validation_datasets = (
_get_validation_datasets_input_data_specification_dto()
if spec.validation
else None
)
return InputDataSpecificationsDTO(
y=y,
x=x,
validation_datasets=validation_datasets,
data_model=_get_data_model_input_data_specification_dto(),
datasets=_get_datasets_input_data_specification_dto(),
filter=_get_filters_input_data_specification_dto(),
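
Illustrative only: when an algorithm spec sets "validation": true, the serialized inputdata specification gains a validation_datasets entry built from the helper above, roughly as below (the serialized form is an assumption, e.g. that InputDataType.TEXT maps to "text").

# Rough sketch of the serialized validation_datasets spec entry (field values from the helper above).
validation_datasets_spec = {
    "label": "Set of data to validate.",
    "desc": "The set of data to validate the algorithm model on.",
    "types": ["text"],   # assumes InputDataType.TEXT serializes to "text"
    "notblank": True,
    "multiple": True,
    "stattypes": None,
    "enumslen": None,
}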
16 changes: 11 additions & 5 deletions exareme2/controller/services/flower/controller.py
@@ -1,5 +1,4 @@
import asyncio
import warnings
from typing import Dict
from typing import List

@@ -51,16 +50,21 @@ def _create_worker_tasks_handler(self, request_id, worker_info: WorkerInfo):
)

async def exec_algorithm(self, algorithm_name, algorithm_request_dto):
async with self.lock:
async with (self.lock):
request_id = algorithm_request_dto.request_id
context_id = UIDGenerator().get_a_uid()
logger = ctrl_logger.get_request_logger(request_id)
datasets = algorithm_request_dto.inputdata.datasets + (
algorithm_request_dto.inputdata.validation_datasets
if algorithm_request_dto.inputdata.validation_datasets
else []
)
csv_paths_per_worker_id: Dict[
str, List[str]
] = self.worker_landscape_aggregator.get_csv_paths_per_worker_id(
algorithm_request_dto.inputdata.data_model,
algorithm_request_dto.inputdata.datasets,
algorithm_request_dto.inputdata.data_model, datasets
)

workers_info = [
self.worker_landscape_aggregator.get_worker_info(worker_id)
for worker_id in csv_paths_per_worker_id
@@ -93,7 +97,9 @@ async def exec_algorithm(self, algorithm_name, algorithm_request_dto):
algorithm_name,
len(task_handlers),
str(server_address),
csv_paths_per_worker_id[server_id],
csv_paths_per_worker_id[server_id]
if algorithm_request_dto.inputdata.validation_datasets
else [],
)
clients_pids = {
handler.start_flower_client(
@@ -26,6 +26,7 @@
)
from exareme2.controller.workers_addresses import WorkersAddressesFactory
from exareme2.utils import AttrDict
from exareme2.worker_communication import BadUserInput
from exareme2.worker_communication import CommonDataElement
from exareme2.worker_communication import CommonDataElements
from exareme2.worker_communication import DataModelAttributes
@@ -145,6 +146,7 @@ def get_csv_paths_per_worker_id(
].items()
if dataset in datasets
]

for dataset_info in dataset_infos:
if not dataset_info.csv_path:
raise DatasetMissingCsvPathError()
@@ -514,6 +516,37 @@ def get_cdes_per_data_model(self) -> DataModelsCDES:
def get_datasets_locations(self) -> DatasetsLocations:
return self._registries.data_model_registry.datasets_locations

def get_train_and_validation_datasets(
self, data_model: str
) -> Tuple[List[str], List[str]]:
"""
Retrieves all available training and validation datasets for a specific data model.
Parameters:
data_model (str): The data model for which to retrieve datasets.
Returns:
Tuple[List[str], List[str]]: A tuple containing two lists:
- The first list contains training datasets.
- The second list contains validation datasets.
"""
training_datasets = []
validation_datasets = []

if data_model not in self.get_datasets_locations().datasets_locations.keys():
raise BadUserInput(f"Data model '{data_model}' does not exist.")
datasets_locations = self.get_datasets_locations().datasets_locations[
data_model
]

for dataset, dataset_location in datasets_locations.items():
if dataset_location.worker_id == self.get_global_worker().id:
validation_datasets.append(dataset)
else:
training_datasets.append(dataset)

return training_datasets, validation_datasets
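
For context (illustrative, not part of the diff): the split rule is simply that datasets hosted on the global worker become validation datasets, while datasets on local workers remain training datasets. A hedged sketch of a call, where aggregator stands for a WorkerLandscapeAggregator instance and the data model name is assumed:

# Illustrative only.
training, validation = aggregator.get_train_and_validation_datasets("dementia:0.1")
# training   -> e.g. ["ppmi7", "ppmi8", "ppmi9"]  (located on local workers)
# validation -> e.g. ["ppmi_test"]                (located on the global worker)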

def get_all_available_datasets_per_data_model(self) -> Dict[str, List[str]]:
return (
self._registries.data_model_registry.get_all_available_datasets_per_data_model()
@@ -815,8 +848,8 @@ def _remove_incompatible_data_models_from_data_models_metadata_per_worker(
data_models_metadata_per_worker: DataModelsMetadataPerWorker
Returns
----------
List[str]
The incompatible data models
DataModelsMetadataPerWorker
The data_models_metadata_per_worker with the incompatible data models removed
"""
validation_dictionary = {}

@@ -15,8 +15,8 @@ def test_logistic_regression(get_algorithm_result):
"ppmi7",
"ppmi8",
"ppmi9",
"ppmi_test",
],
"validation_datasets": ["ppmi_test"],
"filters": None,
},
"parameters": None,
@@ -44,8 +44,8 @@ def test_logistic_regression_with_filters(get_algorithm_result):
"ppmi7",
"ppmi8",
"ppmi9",
"ppmi_test",
],
"validation_datasets": ["ppmi_test"],
"filters": {
"condition": "AND",
"rules": [