Commit 45baf8f2 authored by Joaquin Torres's avatar Joaquin Torres

Completed comments

parent 11a14251
# Computing SHAP Interaction Values
# Author: Joaquín Torres Bravo
"""
Script to compute SHAP interaction values for chosen models.
"""
# Libraries # Libraries
# -------------------------------------------------------------------------------------------------------- # --------------------------------------------------------------------------------------------------------
import pandas as pd import pandas as pd
import numpy as np import numpy as np
import shap import shap # Explainability
import pickle import pickle # Loading/saving models
# Models
from xgboost import XGBClassifier from xgboost import XGBClassifier
from sklearn.ensemble import RandomForestClassifier, BaggingClassifier, AdaBoostClassifier from sklearn.ensemble import RandomForestClassifier, BaggingClassifier, AdaBoostClassifier
from sklearn.neural_network import MLPClassifier from sklearn.neural_network import MLPClassifier
...@@ -60,10 +67,6 @@ if __name__ == "__main__": ...@@ -60,10 +67,6 @@ if __name__ == "__main__":
X_test = data_dic['X_test_' + group] X_test = data_dic['X_test_' + group]
y_test = data_dic['y_test_' + group] y_test = data_dic['y_test_' + group]
for j, method in enumerate(['', '', 'over_', 'under_']): for j, method in enumerate(['', '', 'over_', 'under_']):
# Remove (used to isolate RF)
# if j != 1:
# print('Skip')
# continue
print(f"{group}-{method_names[j]}") print(f"{group}-{method_names[j]}")
method_name = method_names[j] method_name = method_names[j]
model_name = model_choices[method_name] model_name = model_choices[method_name]
...@@ -76,7 +79,7 @@ if __name__ == "__main__": ...@@ -76,7 +79,7 @@ if __name__ == "__main__":
if is_tree: if is_tree:
explainer = shap.TreeExplainer(fitted_model) explainer = shap.TreeExplainer(fitted_model)
# else: # else:
# explainer = shap.KernelExplainer(fitted_model.predict_proba, X_test[:500]) # explainer = shap.KernelExplainer...
# Compute shap values # Compute shap values
shap_interaction_values = explainer.shap_interaction_values(X_test) shap_interaction_values = explainer.shap_interaction_values(X_test)
# --------------------------------------------------------------------------------------------------------- # ---------------------------------------------------------------------------------------------------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment