Skip to content

Commit

Permalink
Fix of copying categorical features with choosen idxes
Browse files Browse the repository at this point in the history
  • Loading branch information
aPovidlo committed Jul 22, 2024
1 parent e0b4ee7 commit 8f37a56
Showing 1 changed file with 19 additions and 12 deletions.
31 changes: 19 additions & 12 deletions fedot/core/data/data_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,21 +33,28 @@ def _split_input_data_by_indexes(origin_input_data: Union[InputData, MultiModalD
target = np.take(origin_input_data.target, index, 0)
features = np.take(origin_input_data.features, index, 0)

if origin_input_data.categorical_features is not None:
categorical_features = np.take(origin_input_data.categorical_features, index, 0)
else:
categorical_features = origin_input_data.categorical_features

if retain_first_target and len(target.shape) > 1:
target = target[:, 0]

data = InputData(idx=idx,
features=features,
target=target,
task=deepcopy(origin_input_data.task),
data_type=origin_input_data.data_type,
supplementary_data=origin_input_data.supplementary_data,
categorical_features=origin_input_data.categorical_features,
categorical_idx=origin_input_data.categorical_idx,
numerical_idx=origin_input_data.numerical_idx,
encoded_idx=origin_input_data.encoded_idx,
features_names=origin_input_data.features_names,
)
data = InputData(
idx=idx,
features=features,
target=target,
task=deepcopy(origin_input_data.task),
data_type=origin_input_data.data_type,
supplementary_data=origin_input_data.supplementary_data,
categorical_features=categorical_features,
categorical_idx=origin_input_data.categorical_idx,
numerical_idx=origin_input_data.numerical_idx,
encoded_idx=origin_input_data.encoded_idx,
features_names=origin_input_data.features_names,
)

return data
else:
raise TypeError(f'Unknown data type {type(origin_input_data)}')
Expand Down

0 comments on commit 8f37a56

Please sign in to comment.