Skip to content

Commit

Permalink
Adding splitting test with cat features
Browse files Browse the repository at this point in the history
  • Loading branch information
aPovidlo committed Jul 23, 2024
1 parent 8f37a56 commit 83328db
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions test/unit/data/test_data_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pandas as pd
import pytest

from fedot.api.api_utils.api_data import ApiDataProcessor
from fedot.core.data.data import InputData
from fedot.core.data.data_split import train_test_data_setup
from fedot.core.data.multi_modal import MultiModalData
Expand All @@ -18,6 +19,8 @@
from test.unit.tasks.test_forecasting import get_ts_data

TABULAR_SIMPLE = {'train_features_size': (8, 5), 'test_features_size': (2, 5), 'test_idx': (8, 9)}
TABULAR_CATEGORICAL = {'train_features_size': (11, 26), 'test_features_size': (3, 26), 'test_idx': (11, 12, 13),
'train_category_size': (11, 4), 'test_category_size': (3, 4)}
TS_SIMPLE = {'train_features_size': (18,), 'test_features_size': (18,), 'test_idx': (18, 19)}
TEXT_SIMPLE = {'train_features_size': (8,), 'test_features_size': (2,), 'test_idx': (8, 9)}
IMAGE_SIMPLE = {'train_features_size': (8, 5, 5, 2), 'test_features_size': (2, 5, 5, 2), 'test_idx': (8, 9)}
Expand Down Expand Up @@ -107,6 +110,32 @@ def get_balanced_data_to_test_mismatch():
return input_data


def get_tabular_classification_data_with_cats():
task = Task(TaskTypesEnum.classification)
x = np.array([[0, 0, 15, 'cat', 'left'],
[0, 1, 2, 'cat', 'right'],
[8, 12, 0, 'dog', 'left'],
[0, 1, 0, 'dog', 'right'],
[1, 1, 0, 'cat', 'left'],
[0, 11, 9, 'cow', 'right'],
[5, 1, 10, 'cat', 'left'],
[8, 16, 4, 'dog', 'right'],
[3, 1, 5, 'cat', 'left'],
[0, 1, 6, 'dog', 'right'],
[2, 7, 9, 'cat', 'left'],
[0, 1, 2, 'dog', 'right'],
[14, 1, 0, 'cat', 'right'],
[0, 4, 10, 'dog', 'left']])
y = np.array([0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1])
input_data = InputData(idx=np.arange(0, len(x)), features=x,
target=y, task=task, data_type=DataTypesEnum.table)

data_preprocessor = ApiDataProcessor(task=task)
preprocessed_input_data = data_preprocessor.fit_transform(input_data)

return preprocessed_input_data


def check_shuffle(sample):
unique = np.unique(np.diff(sample.idx))
test_result = len(unique) > 1 or np.min(unique) > 1
Expand All @@ -133,6 +162,7 @@ def test_split_data():

@pytest.mark.parametrize('data_generator, expected_output',
[(get_tabular_classification_data, TABULAR_SIMPLE),
(get_tabular_classification_data_with_cats, TABULAR_CATEGORICAL),
(get_ts_data_to_forecast_two_elements, TS_SIMPLE),
(get_text_classification_data, TEXT_SIMPLE),
(get_image_classification_data, IMAGE_SIMPLE)])
Expand All @@ -144,6 +174,10 @@ def test_default_train_test_simple(data_generator: Callable, expected_output: di
assert train_data.features.shape == expected_output['train_features_size']
assert test_data.features.shape == expected_output['test_features_size']
assert tuple(test_data.idx) == expected_output['test_idx']
if 'train_category_size' in expected_output:
assert train_data.categorical_features.shape == expected_output['train_category_size']
if 'train_category_size' in expected_output:
assert test_data.categorical_features.shape == expected_output['test_category_size']


def test_multitarget_train_test_split():
Expand Down

0 comments on commit 83328db

Please sign in to comment.