Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix shap value calculation and plot #158

Merged
merged 9 commits into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions nkululeko/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-


from typing import List
import configparser
import time
Expand All @@ -26,10 +27,15 @@

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.metrics import(
RocCurveDisplay,
balanced_accuracy_score,
classification_report,
f1_score
auc,
roc_auc_score,
roc_curve
)

from nkululeko.constants import VERSION
Expand Down Expand Up @@ -289,9 +295,7 @@ def ensemble_predictions(
uar = balanced_accuracy_score(truth, predicted)
acc = (truth == predicted).mean()
# print classification report
Util("ensemble").debug(f"\n {classification_report(truth, predicted)}")
# f1 = f1_score(truth, predicted, pos_label='p')
# Util("ensemble").debug(f"F1: {f1:.3f}")
Util("ensemble").debug(f"\n {classification_report(truth, predicted, digits=4)}")
Util("ensemble").debug(f"{method}: UAR: {uar:.3f}, ACC: {acc:.3f}")

return ensemble_preds
Expand Down
4 changes: 3 additions & 1 deletion nkululeko/explore.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ def main(src_dir):
# these investigations need features to explore
expr.extract_feats()
needs_feats = True
# explore
# explore
expr.init_runmanager()
expr.runmgr.do_runs()
expr.analyse_features(needs_feats)
expr.store_report()
print("DONE")
Expand Down
20 changes: 17 additions & 3 deletions nkululeko/feat_extract/feats_analyser.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,19 +50,32 @@ def analyse_shap(self, model):

name = "my_shap_values"
if not self.util.exist_pickle(name):

# get model name
model_name = self.util.get_model_type()
if hasattr(model, "predict_shap"):
model_func = model.predict_shap
elif hasattr(model, "clf"):
model_func = model.clf.predict
else:
raise Exception("Model not supported for SHAP analysis")

self.util.debug(f"using SHAP explainer for {model_name} model")

explainer = shap.Explainer(
model.predict_shap,
model_func,
self.features,
output_names=glob_conf.labels,
algorithm="permutation",
npermutations=5,
)

self.util.debug("computing SHAP values...")
shap_values = explainer(self.features)
self.util.to_pickle(shap_values, name)
else:
shap_values = self.util.from_pickle(name)
# plt.figure()
plt.close('all')
plt.tight_layout()
shap.plots.bar(shap_values)
fig_dir = self.util.get_path("fig_dir") + "../" # one up because of the runs
Expand All @@ -71,7 +84,8 @@ def analyse_shap(self, model):
filename = f"_SHAP_{model.name}"
filename = f"{fig_dir}{exp_name}{filename}.{format}"
plt.savefig(filename)
self.util.debug(f"plotted SHAP feature importance tp {filename}")
plt.close()
self.util.debug(f"plotted SHAP feature importance to {filename}")

def analyse(self):
models = ast.literal_eval(self.util.config_val("EXPL", "model", "['log_reg']"))
Expand Down
4 changes: 2 additions & 2 deletions nkululeko/modelrunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ def do_epochs(self):
# epochs are handled by Huggingface API
self.model.train()
report = self.model.predict()
# todo: findout the best epoch, no need
# since oad_best_model_at_end is given in training args
# todo: findout the best epoch -> no need
# since load_best_model_at_end is given in training args
epoch = epoch_num
report.set_id(self.run, epoch)
plot_name = self.util.get_plot_name() + f"_{self.run}_{epoch:03d}_cnf"
Expand Down
7 changes: 4 additions & 3 deletions nkululeko/reporting/reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def print_results(self, epoch=None):
)
# print classifcation report in console
self.util.debug(
f"\n {classification_report(self.truths, self.preds, target_names=labels)}"
f"\n {classification_report(self.truths, self.preds, target_names=labels, digits=4)}"
)
except ValueError as e:
self.util.debug(
Expand All @@ -422,16 +422,17 @@ def print_results(self, epoch=None):
if len(np.unique(self.truths)) == 2:
fpr, tpr, _ = roc_curve(self.truths, self.preds)
auc_score = auc(fpr, tpr)
plot_path = f"{fig_dir}{self.util.get_exp_name()}_{epoch}{self.filenameadd}_roc.{self.format}"
plt.figure()
display = RocCurveDisplay(
fpr=fpr,
tpr=tpr,
roc_auc=auc_score,
estimator_name=f"{self.model_type} estimator",
)
# save plot
plot_path = f"{fig_dir}{self.util.get_exp_name()}_{epoch}{self.filenameadd}_roc.{self.format}"
display.plot(ax=None)
plt.savefig(plot_path)
plt.close()
self.util.debug(f"Saved ROC curve to {plot_path}")
pauc_score = roc_auc_score(self.truths, self.preds, max_fpr=0.1)
auc_pauc = f"auc: {auc_score:.3f}, pauc: {pauc_score:.3f} from epoch: {epoch}"
Expand Down
12 changes: 10 additions & 2 deletions run_test2.sh
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ function RunTest {
fi
}

# resample before performing other tests
resample_ini_files=(
exp_polish_gmm.ini
)
# test basic nkululeko
nkululeko_ini_files=(
exp_emodb_os_praat_xgb.ini
Expand All @@ -70,6 +74,7 @@ nkululeko_ini_files=(
exp_emodb_os_mlp.ini
exp_agedb_os_xgr.ini
exp_agedb_os_mlp.ini
exp_polish_gmm.ini
)

# test augmentation
Expand Down Expand Up @@ -114,6 +119,7 @@ explore_ini_files=(
exp_emodb_explore_scatter.ini
exp_emodb_explore_features.ini
exp_agedb_explore_data.ini
exp_polish_gmm.ini # shap
exp_explore.ini # test splotlight
)

Expand All @@ -132,8 +138,10 @@ start_time=$(date +%s)
if [ "$1" == "all" ]; then
modules=(nkululeko augment predict demo test multidb explore)
elif [ "$1" == "-spotlight" ]; then
modules=(nkululeko augment predict demo test multidb explore)
unset explore_ini_files[-1] # Exclude INI file for spotlight
modules=(resample nkululeko augment predict demo test multidb explore)
# unset last two ini files to exclude spotlight and shap
unset explore_ini_files[-1] # Exclude INI file for spotlight
unset explore_ini_files[-1] # and shap
else
modules=("$@")
fi
Expand Down
8 changes: 5 additions & 3 deletions tests/exp_polish_gmm.ini
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[EXP]
root = /tmp/results/
root = ./tests/results/
name = exp_polish_os
save = True
[DATA]
databases = ['train', 'dev', 'test']
train = ./data/polish/polish_train.csv
Expand Down Expand Up @@ -28,9 +29,10 @@ scale = standard
type = gmm
GMM_components = 3
GMM_covariance_type = diag
; save = True
save = True
[RESAMPLE]
replace = True
[PLOT]
; do not plot anything

[EXPL]
shap = True
Loading