From 45baf8f224cd9ec7f906eac2c9f11387a08e9e85 Mon Sep 17 00:00:00 2001 From: joaquintb Date: Mon, 8 Jul 2024 12:29:08 +0200 Subject: [PATCH] Completed comments --- explainability/compute_shap_inter_vals.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/explainability/compute_shap_inter_vals.py b/explainability/compute_shap_inter_vals.py index a16d8c3..cebd243 100644 --- a/explainability/compute_shap_inter_vals.py +++ b/explainability/compute_shap_inter_vals.py @@ -1,9 +1,16 @@ +# Computing SHAP Interaction Values +# Author: JoaquĆ­n Torres Bravo +""" + Script to compute SHAP interaction values for chosen models. +""" + # Libraries # -------------------------------------------------------------------------------------------------------- import pandas as pd import numpy as np -import shap -import pickle +import shap # Explainability +import pickle # Loading/saving models +# Models from xgboost import XGBClassifier from sklearn.ensemble import RandomForestClassifier, BaggingClassifier, AdaBoostClassifier from sklearn.neural_network import MLPClassifier @@ -60,10 +67,6 @@ if __name__ == "__main__": X_test = data_dic['X_test_' + group] y_test = data_dic['y_test_' + group] for j, method in enumerate(['', '', 'over_', 'under_']): - # Remove (used to isolate RF) - # if j != 1: - # print('Skip') - # continue print(f"{group}-{method_names[j]}") method_name = method_names[j] model_name = model_choices[method_name] @@ -76,7 +79,7 @@ if __name__ == "__main__": if is_tree: explainer = shap.TreeExplainer(fitted_model) # else: - # explainer = shap.KernelExplainer(fitted_model.predict_proba, X_test[:500]) + # explainer = shap.KernelExplainer... # Compute shap values shap_interaction_values = explainer.shap_interaction_values(X_test) # --------------------------------------------------------------------------------------------------------- -- 2.24.1