diff --git a/explicability/output/plots/inters/post_OVER_XGB.svg b/explicability/output/plots/inters/post_OVER_XGB.svg
new file mode 100644
index 0000000000000000000000000000000000000000..cc235acff6c16d602106ac47b9306527a19cb096
--- /dev/null
+++ b/explicability/output/plots/inters/post_OVER_XGB.svg
@@ -0,0 +1,2670 @@
+
+
+
diff --git a/explicability/output/plots/inters/pre_OVER_XGB.svg b/explicability/output/plots/inters/pre_OVER_XGB.svg
new file mode 100644
index 0000000000000000000000000000000000000000..d20a868917308e8edfc84625037ff4a3366247c5
--- /dev/null
+++ b/explicability/output/plots/inters/pre_OVER_XGB.svg
@@ -0,0 +1,2606 @@
+
+
+
diff --git a/explicability/shap_plots.py b/explicability/shap_plots.py
index f7b7154bea3acbf27e674c12b13409b8fbe56fff..dce32a02cd3914ffcb2962b8768e30037325946c 100644
--- a/explicability/shap_plots.py
+++ b/explicability/shap_plots.py
@@ -3,6 +3,7 @@
import pandas as pd
import numpy as np
import shap
+import matplotlib.pyplot as plt
# --------------------------------------------------------------------------------------------------------
# Reading test data
@@ -37,18 +38,25 @@ if __name__ == "__main__":
2: "OVER",
3: "UNDER"
}
+ model_choices = {
+ "ORIG": "XGB",
+ "ORIG_CW": "RF",
+ "OVER": "XGB",
+ "UNDER": "XGB"
+ }
# --------------------------------------------------------------------------------------------------------
# Plot generation
# --------------------------------------------------------------------------------------------------------
- for i, group in enumerate(['pre', 'post']):
- # Get test dataset based on group, add column names
- X_test = data_dic['X_test_' + group]
- y_test = data_dic['y_test_' + group]
- for j, method in enumerate(['', '', 'over_', 'under_']):
+ for j, method in enumerate(['', '', 'over_', 'under_']):
+ if method != 'over_':
+ continue
+ for i, group in enumerate(['pre', 'post']):
+ X_test = data_dic['X_test_' + group]
+ y_test = data_dic['y_test_' + group]
print(f"{group}-{method_names[j]}")
method_name = method_names[j]
- shap_vals = np.load(f'./output/shap_values/{group}_{method_name}.npy')
- # print(f'Loaded SHAP values. Shape: {shap_vals.shape}')
+ model_name = model_choices[method_name]
shap_inter_vals = np.load(f'./output/shap_inter_values/{group}_{method_name}.npy')
- # print(f'Loaded SHAP INTER values. Shape: {shap_inter_vals.shape}')
+ shap.summary_plot(shap_inter_vals, X_test, max_display=15)
+ plt.savefig(f'./output/plots/inters/{group}_{method_name}_{model_name}.svg', format='svg', dpi=1250)
\ No newline at end of file