diff --git a/explicability/compute_shap_inter_vals.py b/explicability/compute_shap_inter_vals.py new file mode 100644 index 0000000000000000000000000000000000000000..43a75eebbb97aa0fd7df3aabe3acc0b30f557517 --- /dev/null +++ b/explicability/compute_shap_inter_vals.py @@ -0,0 +1,81 @@ +# Libraries +# -------------------------------------------------------------------------------------------------------- +import pandas as pd +import numpy as np +import shap +import pickle +from xgboost import XGBClassifier +from sklearn.ensemble import RandomForestClassifier, BaggingClassifier, AdaBoostClassifier +from sklearn.neural_network import MLPClassifier +from sklearn.svm import SVC +from sklearn.linear_model import LogisticRegression +from sklearn.tree import DecisionTreeClassifier +# -------------------------------------------------------------------------------------------------------- + +# Reading test data +# -------------------------------------------------------------------------------------------------------- +def read_test_data(attribute_names): + # Load test data + X_test_pre = np.load('../gen_train_data/data/output/pre/X_test_pre.npy', allow_pickle=True) + y_test_pre = np.load('../gen_train_data/data/output/pre/y_test_pre.npy', allow_pickle=True) + X_test_post = np.load('../gen_train_data/data/output/post/X_test_post.npy', allow_pickle=True) + y_test_post = np.load('../gen_train_data/data/output/post/y_test_post.npy', allow_pickle=True) + + # Type conversion needed + data_dic = { + "X_test_pre": pd.DataFrame(X_test_pre, columns=attribute_names).convert_dtypes(), + "y_test_pre": y_test_pre, + "X_test_post": pd.DataFrame(X_test_post, columns=attribute_names).convert_dtypes(), + "y_test_post": y_test_post, + } + return data_dic +# -------------------------------------------------------------------------------------------------------- + +if __name__ == "__main__": + + # Setup + # -------------------------------------------------------------------------------------------------------- + # Retrieve attribute names in order + attribute_names = list(np.load('../gen_train_data/data/output/attributes.npy', allow_pickle=True)) + # Reading data + data_dic = read_test_data(attribute_names) + method_names = { + 0: "ORIG", + 1: "ORIG_CW", + 2: "OVER", + 3: "UNDER" + } + model_choices = { + "ORIG": "XGB", + "ORIG_CW": "RF", + "OVER": "XGB", + "UNDER": "XGB" + } + # -------------------------------------------------------------------------------------------------------- + + # Shap value generation + # -------------------------------------------------------------------------------------------------------- + for i, group in enumerate(['pre', 'post']): + # Get test dataset based on group, add column names + X_test = data_dic['X_test_' + group] + y_test = data_dic['y_test_' + group] + for j, method in enumerate(['', '', 'over_', 'under_']): + print(f"{group}-{method_names[j]}") + method_name = method_names[j] + model_name = model_choices[method_name] + model_path = f"./output/fitted_models/{group}_{method_names[j]}_{model_name}.pkl" + # Load the fitted model from disk + with open(model_path, 'rb') as file: + fitted_model = pickle.load(file) + # Check if we are dealing with a tree vs nn model + is_tree = model_name not in ['LR', 'SVM', 'MLP'] + if is_tree: + explainer = shap.TreeExplainer(fitted_model) + # else: + # explainer = shap.KernelExplainer(fitted_model.predict_proba, X_test[:500]) + # Compute shap values + shap_interaction_values = explainer.shap_interaction_values(X_test) + # --------------------------------------------------------------------------------------------------------- + # Save results + np.save(f"./output/shap_inter_values/{group}_{method_names[j]}", shap_interaction_values) + # -------------------------------------------------------------------------------------------------------- \ No newline at end of file