Commit a8431993 authored by Joaquin Torres's avatar Joaquin Torres

Script to compute shap interaction values

parent 484d8ae1
# Libraries
# --------------------------------------------------------------------------------------------------------
import pandas as pd
import numpy as np
import shap
import pickle
from xgboost import XGBClassifier
from sklearn.ensemble import RandomForestClassifier, BaggingClassifier, AdaBoostClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
# --------------------------------------------------------------------------------------------------------
# Reading test data
# --------------------------------------------------------------------------------------------------------
def read_test_data(attribute_names):
# Load test data
X_test_pre = np.load('../gen_train_data/data/output/pre/X_test_pre.npy', allow_pickle=True)
y_test_pre = np.load('../gen_train_data/data/output/pre/y_test_pre.npy', allow_pickle=True)
X_test_post = np.load('../gen_train_data/data/output/post/X_test_post.npy', allow_pickle=True)
y_test_post = np.load('../gen_train_data/data/output/post/y_test_post.npy', allow_pickle=True)
# Type conversion needed
data_dic = {
"X_test_pre": pd.DataFrame(X_test_pre, columns=attribute_names).convert_dtypes(),
"y_test_pre": y_test_pre,
"X_test_post": pd.DataFrame(X_test_post, columns=attribute_names).convert_dtypes(),
"y_test_post": y_test_post,
}
return data_dic
# --------------------------------------------------------------------------------------------------------
if __name__ == "__main__":
# Setup
# --------------------------------------------------------------------------------------------------------
# Retrieve attribute names in order
attribute_names = list(np.load('../gen_train_data/data/output/attributes.npy', allow_pickle=True))
# Reading data
data_dic = read_test_data(attribute_names)
method_names = {
0: "ORIG",
1: "ORIG_CW",
2: "OVER",
3: "UNDER"
}
model_choices = {
"ORIG": "XGB",
"ORIG_CW": "RF",
"OVER": "XGB",
"UNDER": "XGB"
}
# --------------------------------------------------------------------------------------------------------
# Shap value generation
# --------------------------------------------------------------------------------------------------------
for i, group in enumerate(['pre', 'post']):
# Get test dataset based on group, add column names
X_test = data_dic['X_test_' + group]
y_test = data_dic['y_test_' + group]
for j, method in enumerate(['', '', 'over_', 'under_']):
print(f"{group}-{method_names[j]}")
method_name = method_names[j]
model_name = model_choices[method_name]
model_path = f"./output/fitted_models/{group}_{method_names[j]}_{model_name}.pkl"
# Load the fitted model from disk
with open(model_path, 'rb') as file:
fitted_model = pickle.load(file)
# Check if we are dealing with a tree vs nn model
is_tree = model_name not in ['LR', 'SVM', 'MLP']
if is_tree:
explainer = shap.TreeExplainer(fitted_model)
# else:
# explainer = shap.KernelExplainer(fitted_model.predict_proba, X_test[:500])
# Compute shap values
shap_interaction_values = explainer.shap_interaction_values(X_test)
# ---------------------------------------------------------------------------------------------------------
# Save results
np.save(f"./output/shap_inter_values/{group}_{method_names[j]}", shap_interaction_values)
# --------------------------------------------------------------------------------------------------------
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment