Commit 36a3534e authored by Joaquin Torres

script prepared to compute shap values

parent e97c990a
@@ -2,12 +2,111 @@
# --------------------------------------------------------------------------------------------------------
import pandas as pd
import numpy as np
import shap
from xgboost import XGBClassifier
from sklearn.metrics import confusion_matrix
from sklearn.metrics import f1_score, make_scorer, precision_score, recall_score, accuracy_score
from sklearn.ensemble import RandomForestClassifier, BaggingClassifier, AdaBoostClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
# --------------------------------------------------------------------------------------------------------
# Reading test and training data
# --------------------------------------------------------------------------------------------------------
def read_data():
    # Load test data
    X_test_pre = np.load('../gen_train_data/data/output/pre/X_test_pre.npy', allow_pickle=True)
    y_test_pre = np.load('../gen_train_data/data/output/pre/y_test_pre.npy', allow_pickle=True)
    X_test_post = np.load('../gen_train_data/data/output/post/X_test_post.npy', allow_pickle=True)
    y_test_post = np.load('../gen_train_data/data/output/post/y_test_post.npy', allow_pickle=True)
    # Load ORIGINAL training data
    X_train_pre = np.load('../gen_train_data/data/output/pre/X_train_pre.npy', allow_pickle=True)
    y_train_pre = np.load('../gen_train_data/data/output/pre/y_train_pre.npy', allow_pickle=True)
    X_train_post = np.load('../gen_train_data/data/output/post/X_train_post.npy', allow_pickle=True)
    y_train_post = np.load('../gen_train_data/data/output/post/y_train_post.npy', allow_pickle=True)
    # Load oversampled training data
    X_train_over_pre = np.load('../gen_train_data/data/output/pre/X_train_over_pre.npy', allow_pickle=True)
    y_train_over_pre = np.load('../gen_train_data/data/output/pre/y_train_over_pre.npy', allow_pickle=True)
    X_train_over_post = np.load('../gen_train_data/data/output/post/X_train_over_post.npy', allow_pickle=True)
    y_train_over_post = np.load('../gen_train_data/data/output/post/y_train_over_post.npy', allow_pickle=True)
    # Load undersampled training data
    X_train_under_pre = np.load('../gen_train_data/data/output/pre/X_train_under_pre.npy', allow_pickle=True)
    y_train_under_pre = np.load('../gen_train_data/data/output/pre/y_train_under_pre.npy', allow_pickle=True)
    X_train_under_post = np.load('../gen_train_data/data/output/post/X_train_under_post.npy', allow_pickle=True)
    y_train_under_post = np.load('../gen_train_data/data/output/post/y_train_under_post.npy', allow_pickle=True)
    data_dic = {
        "X_test_pre": X_test_pre,
        "y_test_pre": y_test_pre,
        "X_test_post": X_test_post,
        "y_test_post": y_test_post,
        "X_train_pre": X_train_pre,
        "y_train_pre": y_train_pre,
        "X_train_post": X_train_post,
        "y_train_post": y_train_post,
        "X_train_over_pre": X_train_over_pre,
        "y_train_over_pre": y_train_over_pre,
        "X_train_over_post": X_train_over_post,
        "y_train_over_post": y_train_over_post,
        "X_train_under_pre": X_train_under_pre,
        "y_train_under_pre": y_train_under_pre,
        "X_train_under_post": X_train_under_post,
        "y_train_under_post": y_train_under_post,
    }
    return data_dic
# --------------------------------------------------------------------------------------------------------
if __name__ == "__main__":
    # Setup
    # --------------------------------------------------------------------------------------------------------
    # Reading data
    data_dic = read_data()
    method_names = {
        0: "ORIG",
        1: "ORIG_CW",
        2: "OVER",
        3: "UNDER"
    }
    # Best model initialization (to be completed manually)
    # Mapping group-method -> (isTreeModel:bool, model)
    models = {
        "pre_ORIG": (None, None),
        "pre_ORIG_CW": (None, None),
        "pre_OVER": (None, None),
        "pre_UNDER": (None, None),
        "post_ORIG": (None, None),
        "post_ORIG_CW": (None, None),
        "post_OVER": (None, None),
        "post_UNDER": (None, None),
    }
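    # Hypothetical example of how an entry might be filled in once the best models are known
    # (the classes are imported above; the hyperparameters shown here are placeholders, not the actual tuned values):
    #   models["pre_ORIG"] = (True, RandomForestClassifier(n_estimators=200, random_state=42))
    #   models["pre_ORIG_CW"] = (False, LogisticRegression(class_weight='balanced', max_iter=1000))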
# --------------------------------------------------------------------------------------------------------
# Shap value generation
# --------------------------------------------------------------------------------------------------------
    shap_values = {} # Mapping group-method -> shap values
    for i, group in enumerate(['pre', 'post']):
        # Get test dataset based on group
        X_test = data_dic['X_test_' + group]
        y_test = data_dic['y_test_' + group]
        for j, method in enumerate(['', '', 'over_', 'under_']):
            print(f"{group}-{method_names[j]}")
            # Get train dataset based on group and method
            X_train = data_dic['X_train_' + method + group]
            y_train = data_dic['y_train_' + method + group]
            # Retrieve best model for this group-method context
            model_info = models[group + '_' + method_names[j]]
            is_tree = model_info[0]
            model = model_info[1]
            # Fit model with training data
            fitted_model = model.fit(X_train, y_train) # [:500]?
            # Check if we are dealing with a tree vs nn model
            if is_tree:
                explainer = shap.TreeExplainer(fitted_model, X_test) # [:500]?
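            else:
                # Assumed completion (not part of the original commit): for non-tree models,
                # fall back to the model-agnostic KernelExplainer on predicted probabilities,
                # using a small background sample from the training data to keep runtime manageable.
                explainer = shap.KernelExplainer(fitted_model.predict_proba, X_train[:100])  # background size is an assumption
            # Compute and store SHAP values for this group-method pair (assumed next step)
            shap_values[group + '_' + method_names[j]] = explainer.shap_values(X_test)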
# --------------------------------------------------------------------------------------------------------
\ No newline at end of file
@@ -21,9 +21,9 @@ from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import ast # String to dictionary
# --------------------------------------------------------------------------------------------------------
-# Reading test data
+# Reading data
# --------------------------------------------------------------------------------------------------------
-def read_test_data():
+def read_data():
    # Load test data
    X_test_pre = np.load('../gen_train_data/data/output/pre/X_test_pre.npy', allow_pickle=True)
    y_test_pre = np.load('../gen_train_data/data/output/pre/y_test_pre.npy', allow_pickle=True)
@@ -152,8 +152,8 @@ def negative_recall_scorer(clf, X, y):
# --------------------------------------------------------------------------------------------------------
if __name__ == "__main__":
-    # Reading testing data
-    data_dic = read_test_data()
+    # Reading data
+    data_dic = read_data()
# Setup
# --------------------------------------------------------------------------------------------------------