Commit 2d0dad71 authored by Joaquin Torres's avatar Joaquin Torres

Summary interaction plots gen for OVER

parent b7c120bf
This diff is collapsed.
This diff is collapsed.
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import pandas as pd import pandas as pd
import numpy as np import numpy as np
import shap import shap
import matplotlib.pyplot as plt
# -------------------------------------------------------------------------------------------------------- # --------------------------------------------------------------------------------------------------------
# Reading test data # Reading test data
...@@ -37,18 +38,25 @@ if __name__ == "__main__": ...@@ -37,18 +38,25 @@ if __name__ == "__main__":
2: "OVER", 2: "OVER",
3: "UNDER" 3: "UNDER"
} }
model_choices = {
"ORIG": "XGB",
"ORIG_CW": "RF",
"OVER": "XGB",
"UNDER": "XGB"
}
# -------------------------------------------------------------------------------------------------------- # --------------------------------------------------------------------------------------------------------
# Plot generation # Plot generation
# -------------------------------------------------------------------------------------------------------- # --------------------------------------------------------------------------------------------------------
for i, group in enumerate(['pre', 'post']): for j, method in enumerate(['', '', 'over_', 'under_']):
# Get test dataset based on group, add column names if method != 'over_':
X_test = data_dic['X_test_' + group] continue
y_test = data_dic['y_test_' + group] for i, group in enumerate(['pre', 'post']):
for j, method in enumerate(['', '', 'over_', 'under_']): X_test = data_dic['X_test_' + group]
y_test = data_dic['y_test_' + group]
print(f"{group}-{method_names[j]}") print(f"{group}-{method_names[j]}")
method_name = method_names[j] method_name = method_names[j]
shap_vals = np.load(f'./output/shap_values/{group}_{method_name}.npy') model_name = model_choices[method_name]
# print(f'Loaded SHAP values. Shape: {shap_vals.shape}')
shap_inter_vals = np.load(f'./output/shap_inter_values/{group}_{method_name}.npy') shap_inter_vals = np.load(f'./output/shap_inter_values/{group}_{method_name}.npy')
# print(f'Loaded SHAP INTER values. Shape: {shap_inter_vals.shape}') shap.summary_plot(shap_inter_vals, X_test, max_display=15)
plt.savefig(f'./output/plots/inters/{group}_{method_name}_{model_name}.svg', format='svg', dpi=1250)
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment