Commit 746831df authored by Joaquin Torres's avatar Joaquin Torres

Ready to generate SHAP plots

parent 866332c3
...@@ -60,9 +60,10 @@ if __name__ == "__main__": ...@@ -60,9 +60,10 @@ if __name__ == "__main__":
X_test = data_dic['X_test_' + group] X_test = data_dic['X_test_' + group]
y_test = data_dic['y_test_' + group] y_test = data_dic['y_test_' + group]
for j, method in enumerate(['', '', 'over_', 'under_']): for j, method in enumerate(['', '', 'over_', 'under_']):
if j != 1: # Remove (used to isolate RF)
print('Skip') # if j != 1:
continue # print('Skip')
# continue
print(f"{group}-{method_names[j]}") print(f"{group}-{method_names[j]}")
method_name = method_names[j] method_name = method_names[j]
model_name = model_choices[method_name] model_name = model_choices[method_name]
......
# Libraries
# --------------------------------------------------------------------------------------------------------
import pandas as pd
import numpy as np
import shap
# --------------------------------------------------------------------------------------------------------
# Reading test data
# --------------------------------------------------------------------------------------------------------
def read_test_data(attribute_names):
# Load test data
X_test_pre = np.load('../gen_train_data/data/output/pre/X_test_pre.npy', allow_pickle=True)
y_test_pre = np.load('../gen_train_data/data/output/pre/y_test_pre.npy', allow_pickle=True)
X_test_post = np.load('../gen_train_data/data/output/post/X_test_post.npy', allow_pickle=True)
y_test_post = np.load('../gen_train_data/data/output/post/y_test_post.npy', allow_pickle=True)
# Type conversion needed
data_dic = {
"X_test_pre": pd.DataFrame(X_test_pre, columns=attribute_names).convert_dtypes(),
"y_test_pre": y_test_pre,
"X_test_post": pd.DataFrame(X_test_post, columns=attribute_names).convert_dtypes(),
"y_test_post": y_test_post,
}
return data_dic
# --------------------------------------------------------------------------------------------------------
if __name__ == "__main__":
# Setup
# --------------------------------------------------------------------------------------------------------
# Retrieve attribute names in order
attribute_names = list(np.load('../gen_train_data/data/output/attributes.npy', allow_pickle=True))
# Reading data
data_dic = read_test_data(attribute_names)
method_names = {
0: "ORIG",
1: "ORIG_CW",
2: "OVER",
3: "UNDER"
}
# --------------------------------------------------------------------------------------------------------
# Plot generation
# --------------------------------------------------------------------------------------------------------
for i, group in enumerate(['pre', 'post']):
# Get test dataset based on group, add column names
X_test = data_dic['X_test_' + group]
y_test = data_dic['y_test_' + group]
for j, method in enumerate(['', '', 'over_', 'under_']):
print(f"{group}-{method_names[j]}")
method_name = method_names[j]
shap_vals = np.load(f'./output/shap_values/{group}_{method_name}.npy')
print(f'Loaded SHAP values. Shape: {shap_vals.shape}')
shap_inter_vals = np.load(f'./output/shap_inter_values/{group}_{method_name}.npy')
print(f'Loaded SHAP INTER values. Shape: {shap_inter_vals.shape}')
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment