diff --git a/fedot/core/data/data_split.py b/fedot/core/data/data_split.py index b3ba520fce..73b4f21da2 100644 --- a/fedot/core/data/data_split.py +++ b/fedot/core/data/data_split.py @@ -33,21 +33,28 @@ def _split_input_data_by_indexes(origin_input_data: Union[InputData, MultiModalD target = np.take(origin_input_data.target, index, 0) features = np.take(origin_input_data.features, index, 0) + if origin_input_data.categorical_features is not None: + categorical_features = np.take(origin_input_data.categorical_features, index, 0) + else: + categorical_features = origin_input_data.categorical_features + if retain_first_target and len(target.shape) > 1: target = target[:, 0] - data = InputData(idx=idx, - features=features, - target=target, - task=deepcopy(origin_input_data.task), - data_type=origin_input_data.data_type, - supplementary_data=origin_input_data.supplementary_data, - categorical_features=origin_input_data.categorical_features, - categorical_idx=origin_input_data.categorical_idx, - numerical_idx=origin_input_data.numerical_idx, - encoded_idx=origin_input_data.encoded_idx, - features_names=origin_input_data.features_names, - ) + data = InputData( + idx=idx, + features=features, + target=target, + task=deepcopy(origin_input_data.task), + data_type=origin_input_data.data_type, + supplementary_data=origin_input_data.supplementary_data, + categorical_features=categorical_features, + categorical_idx=origin_input_data.categorical_idx, + numerical_idx=origin_input_data.numerical_idx, + encoded_idx=origin_input_data.encoded_idx, + features_names=origin_input_data.features_names, + ) + return data else: raise TypeError(f'Unknown data type {type(origin_input_data)}')