Commit 2d0dad71 authored by Joaquin Torres's avatar Joaquin Torres

Summary interaction plots gen for OVER

parent b7c120bf
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import pandas as pd import pandas as pd
import numpy as np import numpy as np
import shap import shap
import matplotlib.pyplot as plt
# -------------------------------------------------------------------------------------------------------- # --------------------------------------------------------------------------------------------------------
# Reading test data # Reading test data
...@@ -37,18 +38,25 @@ if __name__ == "__main__": ...@@ -37,18 +38,25 @@ if __name__ == "__main__":
2: "OVER", 2: "OVER",
3: "UNDER" 3: "UNDER"
} }
model_choices = {
"ORIG": "XGB",
"ORIG_CW": "RF",
"OVER": "XGB",
"UNDER": "XGB"
}
# -------------------------------------------------------------------------------------------------------- # --------------------------------------------------------------------------------------------------------
# Plot generation # Plot generation
# -------------------------------------------------------------------------------------------------------- # --------------------------------------------------------------------------------------------------------
for i, group in enumerate(['pre', 'post']): for j, method in enumerate(['', '', 'over_', 'under_']):
# Get test dataset based on group, add column names if method != 'over_':
X_test = data_dic['X_test_' + group] continue
y_test = data_dic['y_test_' + group] for i, group in enumerate(['pre', 'post']):
for j, method in enumerate(['', '', 'over_', 'under_']): X_test = data_dic['X_test_' + group]
y_test = data_dic['y_test_' + group]
print(f"{group}-{method_names[j]}") print(f"{group}-{method_names[j]}")
method_name = method_names[j] method_name = method_names[j]
shap_vals = np.load(f'./output/shap_values/{group}_{method_name}.npy') model_name = model_choices[method_name]
# print(f'Loaded SHAP values. Shape: {shap_vals.shape}')
shap_inter_vals = np.load(f'./output/shap_inter_values/{group}_{method_name}.npy') shap_inter_vals = np.load(f'./output/shap_inter_values/{group}_{method_name}.npy')
# print(f'Loaded SHAP INTER values. Shape: {shap_inter_vals.shape}') shap.summary_plot(shap_inter_vals, X_test, max_display=15)
plt.savefig(f'./output/plots/inters/{group}_{method_name}_{model_name}.svg', format='svg', dpi=1250)
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment