Commit 2d0dad71 authored by Joaquin Torres's avatar Joaquin Torres

Summary interaction plots gen for OVER

parent b7c120bf
This diff is collapsed.
This diff is collapsed.
......@@ -3,6 +3,7 @@
import pandas as pd
import numpy as np
import shap
import matplotlib.pyplot as plt
# --------------------------------------------------------------------------------------------------------
# Reading test data
......@@ -37,18 +38,25 @@ if __name__ == "__main__":
2: "OVER",
3: "UNDER"
}
model_choices = {
"ORIG": "XGB",
"ORIG_CW": "RF",
"OVER": "XGB",
"UNDER": "XGB"
}
# --------------------------------------------------------------------------------------------------------
# Plot generation
# --------------------------------------------------------------------------------------------------------
for i, group in enumerate(['pre', 'post']):
# Get test dataset based on group, add column names
X_test = data_dic['X_test_' + group]
y_test = data_dic['y_test_' + group]
for j, method in enumerate(['', '', 'over_', 'under_']):
for j, method in enumerate(['', '', 'over_', 'under_']):
if method != 'over_':
continue
for i, group in enumerate(['pre', 'post']):
X_test = data_dic['X_test_' + group]
y_test = data_dic['y_test_' + group]
print(f"{group}-{method_names[j]}")
method_name = method_names[j]
shap_vals = np.load(f'./output/shap_values/{group}_{method_name}.npy')
# print(f'Loaded SHAP values. Shape: {shap_vals.shape}')
model_name = model_choices[method_name]
shap_inter_vals = np.load(f'./output/shap_inter_values/{group}_{method_name}.npy')
# print(f'Loaded SHAP INTER values. Shape: {shap_inter_vals.shape}')
shap.summary_plot(shap_inter_vals, X_test, max_display=15)
plt.savefig(f'./output/plots/inters/{group}_{method_name}_{model_name}.svg', format='svg', dpi=1250)
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment