Skip to content

Commit

Permalink
Update feature importance
Browse files Browse the repository at this point in the history
  • Loading branch information
aPovidlo committed Jul 23, 2024
1 parent 604b747 commit 74545d0
Showing 1 changed file with 25 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ def predict(self, input_data: InputData):

return prediction

def get_feature_importance(self) -> list:
return self.model.features_importances_

@staticmethod
def convert_to_dataframe(data: Optional[InputData]):
dataframe = pd.DataFrame(data=data.features, columns=data.features_names)
Expand All @@ -74,6 +77,11 @@ def convert_to_dataframe(data: Optional[InputData]):

return dataframe

def plot_feature_importance(self):
plot_feature_importance(
self.model.feature_names_, self.model.get_boosters().features_importances_
)


class FedotXGBoostClassificationImplementation(FedotXGBoostImplementation):
def __init__(self, params: Optional[OperationParameters] = None):
Expand Down Expand Up @@ -162,6 +170,13 @@ def load_model(self, path):
self.model = CatBoostClassifier()
self.model.load_model(path)

def get_feature_importance(self) -> (list, list):
""" Return feature importance -> (feature_id (string), feature_importance (float)) """
return self.model.get_feature_importance(prettified=True)

def plot_feature_importance(self):
plot_feature_importance(self.model.feature_names_, self.model.features_importances_)


class FedotCatBoostClassificationImplementation(FedotCatBoostImplementation):
def __init__(self, params: Optional[OperationParameters] = None):
Expand All @@ -177,20 +192,18 @@ def predict_proba(self, input_data: InputData):
prediction = self.model.predict_proba(input_data.get_not_encoded_data().features)
return prediction

def get_feature_importance(self):
return self.model.get_feature_importance(prettified=True)

def plot_feature_importance(self):
fi = pd.DataFrame(index=self.model.feature_names_)
fi['importance'] = self.model.feature_importances_

fi.loc[fi['importance'] > 0.1].sort_values('importance').plot(
kind='barh', figsize=(16, 9), title='Feature Importance')

plt.show()


class FedotCatBoostRegressionImplementation(FedotCatBoostImplementation):
def __init__(self, params: Optional[OperationParameters] = None):
super().__init__(params)
self.model = CatBoostRegressor(**self.model_params)


def plot_feature_importance(feature_names, feature_importance):
fi = pd.DataFrame(index=feature_names)
fi['importance'] = feature_importance

fi.loc[fi['importance'] > 0.1].sort_values('importance').plot(
kind='barh', figsize=(16, 9), title='Feature Importance')

plt.show()

0 comments on commit 74545d0

Please sign in to comment.