Commit 77db309d authored by Joaquin Torres

Completing code and regenerating plots with clean attribute names

parent c5bb9749
......@@ -4,7 +4,15 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Libraries"
"**SHAP Explainability Plots** \\\n",
"_Author: Joaquín Torres Bravo_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Libraries"
]
},
{
......@@ -25,7 +33,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Data"
"### Data"
]
},
{
......@@ -71,6 +79,7 @@
" \"UNDER\": \"XGB\"\n",
"}\n",
"\n",
"# Load names of social and individual attributes\n",
"soc_var_names = np.load('../EDA/output/feature_names/social_factors.npy', allow_pickle=True)\n",
"ind_var_names = np.load('../EDA/output/feature_names/individual_factors.npy', allow_pickle=True)"
]
......@@ -79,7 +88,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"##### SHAP Plots"
"### SHAP Plots"
]
},
{
......@@ -89,18 +98,21 @@
"outputs": [],
"source": [
"method_name = 'OVER'\n",
"\n",
"plt.figure(figsize=(35, 75))\n",
"for i, group in enumerate(['pre', 'post']):\n",
" X_test = data_dic['X_test_' + group]\n",
" y_test = data_dic['y_test_' + group]\n",
" model_name = model_choices[method_name]\n",
" shap_vals = np.load(f'./output/shap_values/{group}_{method_name}.npy')\n",
" ax = plt.subplot(2,1,i+1)\n",
" ax = plt.subplot(2,1,i+1) # 2 rows (pre - post) 1 column\n",
" # show = False to modify plot before showing\n",
" shap.summary_plot(shap_vals, X_test, max_display=len(attribute_names), show=False)\n",
" plt.title(group.upper(), fontsize = 12, fontweight='bold')\n",
" plt.xlabel('SHAP Value')\n",
" plt.xlim(-3,5)\n",
" used_colors = {'purple': 'Social factor', 'green': 'Individual factor'}\n",
" # Modify color of attributes\n",
" for label in ax.get_yticklabels():\n",
" label_text = label.get_text() # Get the text of the label\n",
" label.set_fontsize(8)\n",
......@@ -140,13 +152,12 @@
" plt.xlim(-0.5,0.5)\n",
" used_colors = {'purple': 'Social factor', 'green': 'Individual factor'}\n",
" for label in ax.get_yticklabels():\n",
" label_text = label.get_text() # Get the text of the label\n",
" label_text = label.get_text()\n",
" label.set_fontsize(8)\n",
" if label_text in soc_var_names:\n",
" label.set_color('purple')\n",
" else:\n",
" label.set_color('green')\n",
" # Create custom legend for each subplot\n",
" handles = [mpatches.Patch(color=color, label=label) for color, label in used_colors.items()]\n",
" ax.legend(handles=handles, loc='lower right', fontsize=8)\n",
"\n",
......@@ -161,7 +172,18 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"##### SHAP Interaction Plots"
"### SHAP Interaction Plots"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**IMPORTANT NOTE** \\\n",
"For the code to work as intended, the SHAP source code had to be modified:\n",
"* **File**: shap/plots/_beeswarm.py\n",
"* **Line**: 591\n",
"* **Modification**: "
]
},
{
......@@ -253,7 +275,7 @@
},
{
"cell_type": "code",
"execution_count": 29,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
......@@ -262,7 +284,7 @@
},
{
"cell_type": "code",
"execution_count": 30,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
......@@ -273,7 +295,7 @@
},
{
"cell_type": "code",
"execution_count": 31,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
......@@ -397,17 +419,9 @@
},
{
"cell_type": "code",
"execution_count": 35,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Excel file has been created.\n"
]
}
],
"outputs": [],
"source": [
"# Define tolerance\n",
"tolerance = np.median(dist_matrix)\n",
......@@ -462,7 +476,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.5"
"version": "3.12.2"
}
},
"nbformat": 4,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment