shap_plots.ipynb 18.3 KB
Newer Older
1 2 3 4 5 6
{
 "cells": [
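  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**SHAP Explainability Plots** \\\n",
    "_Author: Joaquín Torres Bravo_\n",
    "\n",
    "SHAP summary plots, SHAP interaction plots and a PRE vs POST analysis of how mean SHAP interaction values change for the selected pipelines."
   ]
  },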
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Libraries"
   ]
  },
  {
   "cell_type": "code",
20
   "execution_count": null,
21 22 23 24 25 26
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import shap\n",
27
    "import matplotlib.pyplot as plt\n",
28 29
    "import matplotlib.patches as mpatches\n",
    "import seaborn as sns"
30 31 32 33 34 35
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Retrieve attribute names in order (they become the DataFrame column names below)\n",
    "# NOTE: allow_pickle=True can execute code embedded in a file; only load trusted, locally generated .npy files\n",
    "attribute_names = list(np.load('../EDA/output/feature_names/all_features.npy', allow_pickle=True))\n",
    "\n",
    "# Load test data\n",
    "X_test_pre = np.load('../gen_train_data/output/pre/X_test_pre.npy', allow_pickle=True)\n",
    "y_test_pre = np.load('../gen_train_data/output/pre/y_test_pre.npy', allow_pickle=True)\n",
    "X_test_post = np.load('../gen_train_data/output/post/X_test_post.npy', allow_pickle=True)\n",
    "y_test_post = np.load('../gen_train_data/output/post/y_test_post.npy', allow_pickle=True)\n",
    "\n",
    "# Sanity check: one label per test row in each group\n",
    "assert len(X_test_pre) == len(y_test_pre), 'PRE: X/y length mismatch'\n",
    "assert len(X_test_post) == len(y_test_post), 'POST: X/y length mismatch'\n",
    "\n",
    "# Type conversion needed: wrap the features in named DataFrames and let convert_dtypes()\n",
    "# infer proper column dtypes for the loaded arrays\n",
    "data_dic = {\n",
    "    \"X_test_pre\": pd.DataFrame(X_test_pre, columns=attribute_names).convert_dtypes(),\n",
    "    \"y_test_pre\": y_test_pre,\n",
    "    \"X_test_post\": pd.DataFrame(X_test_post, columns=attribute_names).convert_dtypes(),\n",
    "    \"y_test_post\": y_test_post,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Position -> pipeline identifier\n",
    "method_names = dict(enumerate([\"ORIG\", \"ORIG_CW\", \"OVER\", \"UNDER\"]))\n",
    "\n",
    "# Pipeline identifier -> model used for it: XGB everywhere except ORIG_CW, which uses RF\n",
    "model_choices = {name: \"XGB\" for name in method_names.values()}\n",
    "model_choices[\"ORIG_CW\"] = \"RF\"\n",
    "\n",
    "# Names of social and individual attributes (social names color the plot labels;\n",
    "# both lists are used to classify interaction types)\n",
    "soc_var_names = np.load('../EDA/output/feature_names/social_factors.npy', allow_pickle=True)\n",
    "ind_var_names = np.load('../EDA/output/feature_names/individual_factors.npy', allow_pickle=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### SHAP Plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "method_name = 'OVER'\n",
    "pipeline_label = 'Oversampling'  # human-readable pipeline name for the figure title\n",
    "model_name = model_choices[method_name]\n",
    "used_colors = {'purple': 'Social factor', 'green': 'Individual factor'}\n",
    "# Legend handles do not depend on the subplot, so build them once\n",
    "legend_handles = [mpatches.Patch(color=color, label=label) for color, label in used_colors.items()]\n",
    "\n",
    "# NOTE(review): shap.summary_plot's default plot_size='auto' resizes the current figure,\n",
    "# so this figsize may be overridden; pass plot_size=None to keep it (confirm for the installed shap version).\n",
    "plt.figure(figsize=(35, 75))\n",
    "for i, group in enumerate(['pre', 'post']):\n",
    "    X_test = data_dic['X_test_' + group]\n",
    "    shap_vals = np.load(f'./output/shap_values/{group}_{method_name}.npy')\n",
    "    ax = plt.subplot(2, 1, i + 1)  # 2 rows (pre - post) 1 column\n",
    "    # show=False so the plot can be modified before showing\n",
    "    shap.summary_plot(shap_vals, X_test, max_display=len(attribute_names), show=False)\n",
    "    ax.set_title(group.upper(), fontsize=12, fontweight='bold')\n",
    "    ax.set_xlabel('SHAP Value')\n",
    "    ax.set_xlim(-3, 5)  # same range for PRE and POST so the panels are comparable\n",
    "    # Color each feature label by factor type\n",
    "    for label in ax.get_yticklabels():\n",
    "        label.set_fontsize(8)\n",
    "        label.set_color('purple' if label.get_text() in soc_var_names else 'green')\n",
    "    # One legend per subplot, added once after the labels are styled\n",
    "    ax.legend(handles=legend_handles, loc='lower right', fontsize=8)\n",
    "\n",
    "plt.suptitle(f'SHAP Summary Plots PRE vs POST - Pipeline: {pipeline_label} - Model: {model_name}\\n\\n')\n",
    "# tight_layout recomputes all subplot spacing, so no separate subplots_adjust call is needed\n",
    "plt.tight_layout()\n",
    "plt.savefig(f'./output/plots/shap_summary/{method_name}_{model_name}.svg', format='svg', dpi=1250)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "method_name = 'ORIG_CW'\n",
    "pipeline_label = 'Original with Class Weight'  # human-readable pipeline name for the figure title\n",
    "model_name = model_choices[method_name]\n",
    "used_colors = {'purple': 'Social factor', 'green': 'Individual factor'}\n",
    "# Legend handles do not depend on the subplot, so build them once\n",
    "legend_handles = [mpatches.Patch(color=color, label=label) for color, label in used_colors.items()]\n",
    "\n",
    "# NOTE(review): shap.summary_plot's default plot_size='auto' resizes the current figure,\n",
    "# so this figsize may be overridden; pass plot_size=None to keep it (confirm for the installed shap version).\n",
    "plt.figure(figsize=(35, 75))\n",
    "for i, group in enumerate(['pre', 'post']):\n",
    "    X_test = data_dic['X_test_' + group]\n",
    "    shap_vals = np.load(f'./output/shap_values/{group}_{method_name}.npy')\n",
    "    shap_vals = shap_vals[:, :, 1]  # Select SHAP values for the positive class\n",
    "    ax = plt.subplot(2, 1, i + 1)  # 2 rows (pre - post) 1 column\n",
    "    # show=False so the plot can be modified before showing\n",
    "    shap.summary_plot(shap_vals, X_test, max_display=len(attribute_names), show=False)\n",
    "    ax.set_title(group.upper(), fontsize=12, fontweight='bold')\n",
    "    ax.set_xlabel('SHAP Value')\n",
    "    ax.set_xlim(-0.5, 0.5)  # same range for PRE and POST so the panels are comparable\n",
    "    # Color each feature label by factor type\n",
    "    for label in ax.get_yticklabels():\n",
    "        label.set_fontsize(8)\n",
    "        label.set_color('purple' if label.get_text() in soc_var_names else 'green')\n",
    "    # One legend per subplot, added once after the labels are styled\n",
    "    ax.legend(handles=legend_handles, loc='lower right', fontsize=8)\n",
    "\n",
    "plt.suptitle(f'SHAP Summary Plots PRE vs POST - Pipeline: {pipeline_label} - Model: {model_name}\\n\\n')\n",
    "# tight_layout recomputes all subplot spacing, so no separate subplots_adjust call is needed\n",
    "plt.tight_layout()\n",
    "plt.savefig(f'./output/plots/shap_summary/{method_name}_{model_name}.svg', format='svg', dpi=1250)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### SHAP Interaction Plots"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**IMPORTANT NOTE**\n",
    "\n",
    "For the interaction plots below to work as intended, the installed SHAP source code had to be modified:\n",
    "* **File**: `shap/plots/_beeswarm.py`\n",
    "* **Line**: 591\n",
    "* **Modification**: _not recorded (TODO: document the exact change so the plots can be reproduced)_"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The plots in this section depend on a local modification of `shap/plots/_beeswarm.py` (line 591)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pipeline and group used for the example interaction plot below\n",
    "method_name, group = 'ORIG_CW', 'pre'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_test = data_dic['X_test_' + group]\n",
    "model_name = model_choices[method_name]\n",
    "\n",
    "shap_inter_vals = np.load(f'./output/shap_inter_values/{group}_{method_name}.npy')\n",
    "if method_name == 'ORIG_CW':\n",
    "    shap_inter_vals = shap_inter_vals[:, :, :, 1]  # Take info about positive class\n",
    "\n",
    "# Each instance holds a square feature-by-feature interaction matrix. Set its diagonal and\n",
    "# lower triangle to NaN (presumably so each feature pair is drawn only once). The indices are\n",
    "# the same for every instance, so all instances are masked in one vectorized assignment.\n",
    "num_features = shap_inter_vals.shape[1]\n",
    "lower_rows, lower_cols = np.tril_indices(num_features)  # k=0 keeps the diagonal in the mask\n",
    "shap_inter_vals[:, lower_rows, lower_cols] = np.nan\n",
    "\n",
    "# NOTE(review): for interaction values shap.summary_plot may open its own figure, in which\n",
    "# case this figsize is not used - confirm against the installed shap version.\n",
    "plt.figure(figsize=(10, 10))\n",
    "shap.summary_plot(shap_inter_vals, X_test, show=False, sort=False)\n",
    "fig = plt.gcf()\n",
    "axes = fig.get_axes()\n",
    "used_colors = {'purple': 'Social factor', 'green': 'Individual factor'}\n",
    "x_limits = {'ORIG_CW': (-0.15, 0.15), 'OVER': (-2, 2)}  # same scale for pre and post\n",
    "\n",
    "# Style every y tick label and remember its text in axes order\n",
    "attr_names = []\n",
    "for ax in axes:\n",
    "    for label in ax.get_yticklabels():\n",
    "        label_text = label.get_text()\n",
    "        attr_names.append(label_text)\n",
    "        label.set_fontsize(12)\n",
    "        label.set_color('purple' if label_text in soc_var_names else 'green')\n",
    "\n",
    "# Replace each panel's title with the full feature name. This assumes the first panel's\n",
    "# y tick labels list the panel features bottom-to-top, one per panel (TODO confirm).\n",
    "total_axes = len(axes)\n",
    "for i, ax in enumerate(axes):\n",
    "    title = attr_names[total_axes - 1 - i]\n",
    "    ax.set_title(title, color='purple' if title in soc_var_names else 'green', fontsize=12, rotation=90)\n",
    "    if method_name in x_limits:\n",
    "        ax.set_xlim(*x_limits[method_name])\n",
    "\n",
    "# Create a single general legend for the whole figure\n",
    "handles = [mpatches.Patch(color=color, label=label) for color, label in used_colors.items()]\n",
    "fig.legend(handles=handles, loc='lower right', fontsize=12)\n",
    "\n",
    "fig.suptitle('Simplified Example SHAP Summary Interaction Plot\\n', fontsize=15, fontweight='bold', x=0.5, y=0.95, ha='center')\n",
    "fig.tight_layout()\n",
    "fig.savefig('./output/plots/shap_inter_summary/example.svg', format='svg', dpi=600)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### SHAP Interaction Analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pipeline whose PRE and POST interaction matrices are compared in the cells below\n",
    "method_name = \"OVER\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalization function\n",
    "def max_absolute_scale(matrix):\n",
    "    \"\"\"Scale `matrix` so that its largest absolute entry becomes 1.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    matrix : array-like of numbers\n",
    "        Values to rescale (here, one instance's SHAP interaction matrix).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    numpy.ndarray\n",
    "        ``matrix / max(|matrix|)``. An all-zero input has nothing to scale by, so a float\n",
    "        array of zeros is returned instead of the NaNs that 0/0 would produce (NaNs would\n",
    "        otherwise poison the later mean over instances).\n",
    "    \"\"\"\n",
    "    matrix = np.asarray(matrix)\n",
    "    max_abs = np.max(np.abs(matrix))\n",
    "    if max_abs == 0:\n",
    "        return np.zeros(matrix.shape, dtype=float)\n",
    "    return matrix / max_abs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the SHAP interaction tensors of the chosen pipeline for both groups\n",
    "shap_inter_vals_pre = np.load(f'./output/shap_inter_values/pre_{method_name}.npy')\n",
    "shap_inter_vals_post = np.load(f'./output/shap_inter_values/post_{method_name}.npy')\n",
    "\n",
    "# ORIG_CW tensors carry a trailing class axis: keep the positive class (index 1)\n",
    "if method_name == 'ORIG_CW':\n",
    "    shap_inter_vals_pre, shap_inter_vals_post = shap_inter_vals_pre[:, :, :, 1], shap_inter_vals_post[:, :, :, 1]\n",
    "\n",
    "# Rescale every per-instance matrix by its own maximum absolute value\n",
    "norm_shap_inter_vals_pre, norm_shap_inter_vals_post = (\n",
    "    np.array([max_absolute_scale(instance_matrix) for instance_matrix in stacked])\n",
    "    for stacked in (shap_inter_vals_pre, shap_inter_vals_post)\n",
    ")\n",
    "\n",
    "# Average the normalized matrices over instances within each group\n",
    "agg_shap_inter_vals_pre = norm_shap_inter_vals_pre.mean(axis=0)\n",
    "agg_shap_inter_vals_post = norm_shap_inter_vals_post.mean(axis=0)\n",
    "\n",
    "# Element-wise absolute difference between the two aggregated matrices\n",
    "dist_matrix = np.absolute(agg_shap_inter_vals_post - agg_shap_inter_vals_pre)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_inter_heatmap(matrix, feature_names, soc_var_names, method_name, group):\n",
    "    \"\"\"Plot the lower triangle of a mean SHAP interaction matrix as a heatmap.\n",
    "\n",
    "    Tick labels are colored purple for social factors and green otherwise; the figure\n",
    "    is saved as SVG under ./output/plots/heatmaps_interactions/ and then shown.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    matrix : 2-D array\n",
    "        Square feature-by-feature matrix of interaction values.\n",
    "    feature_names : sequence of str\n",
    "        Row/column labels, in the same order as `matrix`.\n",
    "    soc_var_names : iterable of str\n",
    "        Feature names treated as social factors.\n",
    "    method_name : str\n",
    "        Pipeline identifier, used in the title and output file name.\n",
    "    group : str\n",
    "        Group label (e.g. 'POST'), used in the title and output file name.\n",
    "    \"\"\"\n",
    "    social = set(soc_var_names)  # set membership is O(1) per tick label\n",
    "\n",
    "    # Hide the upper triangle (diagonal included) so each feature pair appears once\n",
    "    mask = np.triu(np.ones_like(matrix, dtype=bool))\n",
    "\n",
    "    plt.figure(figsize=(12, 12))\n",
    "    ax = sns.heatmap(matrix, mask=mask, cmap='coolwarm', center=0, annot=False, cbar=True,\n",
    "                     xticklabels=feature_names, yticklabels=feature_names)\n",
    "\n",
    "    # Color the tick labels of both axes by factor type\n",
    "    for tick_label in [*ax.get_xticklabels(), *ax.get_yticklabels()]:\n",
    "        tick_label.set_color('purple' if tick_label.get_text() in social else 'green')\n",
    "\n",
    "    # Target the heatmap axes explicitly instead of pyplot's implicit current axes,\n",
    "    # which the colorbar created by seaborn could otherwise become\n",
    "    ax.set_title(f'Mean Interaction Matrix - Pipeline: {method_name} - Group: {group}\\n',\n",
    "                 fontdict={'fontstyle': 'normal', 'weight': 'bold'})\n",
    "    purple_patch = mpatches.Patch(color='purple', label='Social factor')\n",
    "    green_patch = mpatches.Patch(color='green', label='Individual factor')\n",
    "    ax.legend(handles=[purple_patch, green_patch], loc='upper right')\n",
    "\n",
    "    # Add a title to the color bar\n",
    "    cbar = ax.collections[0].colorbar\n",
    "    cbar.set_label('Normalized SHAP Interaction Value', labelpad=15, rotation=270, verticalalignment='bottom')\n",
    "\n",
    "    plt.savefig(f'./output/plots/heatmaps_interactions/{method_name}_{group}.svg', format='svg', dpi=600)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot both groups so the PRE heatmap does not require editing and re-running this cell\n",
    "for group_label, agg_matrix in [('PRE', agg_shap_inter_vals_pre), ('POST', agg_shap_inter_vals_post)]:\n",
    "    plot_inter_heatmap(agg_matrix, attribute_names, soc_var_names, method_name, group_label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_distance_heatmap(matrix, feature_names, soc_var_names, method_name):\n",
    "    \"\"\"Plot the lower triangle of the PRE/POST interaction distance matrix as a heatmap.\n",
    "\n",
    "    Tick labels are colored purple for social factors and green otherwise; the figure\n",
    "    is saved as ./output/plots/heatmaps_interactions/DIST_<method_name>.svg and then shown.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    matrix : 2-D array\n",
    "        Square feature-by-feature matrix of distances (here |POST - PRE|).\n",
    "    feature_names : sequence of str\n",
    "        Row/column labels, in the same order as `matrix`.\n",
    "    soc_var_names : iterable of str\n",
    "        Feature names treated as social factors.\n",
    "    method_name : str\n",
    "        Pipeline identifier, used in the title and output file name.\n",
    "    \"\"\"\n",
    "    social = set(soc_var_names)  # set membership is O(1) per tick label\n",
    "\n",
    "    # Hide the upper triangle (diagonal included) so each feature pair appears once\n",
    "    mask = np.triu(np.ones_like(matrix, dtype=bool))\n",
    "\n",
    "    plt.figure(figsize=(12, 12))\n",
    "    ax = sns.heatmap(matrix, mask=mask, cmap=sns.color_palette(\"light:r\", as_cmap=True), annot=False, cbar=True,\n",
    "                     xticklabels=feature_names, yticklabels=feature_names)\n",
    "\n",
    "    # Color the tick labels of both axes by factor type\n",
    "    for tick_label in [*ax.get_xticklabels(), *ax.get_yticklabels()]:\n",
    "        tick_label.set_color('purple' if tick_label.get_text() in social else 'green')\n",
    "\n",
    "    # Target the heatmap axes explicitly instead of pyplot's implicit current axes,\n",
    "    # which the colorbar created by seaborn could otherwise become\n",
    "    ax.set_title(f'Distance Interaction Matrix between PRE and POST - Pipeline: {method_name}\\n',\n",
    "                 fontdict={'fontstyle': 'normal', 'weight': 'bold'})\n",
    "    purple_patch = mpatches.Patch(color='purple', label='Social factor')\n",
    "    green_patch = mpatches.Patch(color='green', label='Individual factor')\n",
    "    ax.legend(handles=[purple_patch, green_patch], loc='upper right')\n",
    "\n",
    "    # Add a title to the color bar\n",
    "    cbar = ax.collections[0].colorbar\n",
    "    cbar.set_label('Abs (PRE - POST)', labelpad=15, rotation=270, verticalalignment='bottom')\n",
    "\n",
    "    plt.savefig(f'./output/plots/heatmaps_interactions/DIST_{method_name}.svg', format='svg', dpi=600)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Heatmap of the element-wise |POST - PRE| distance for the selected pipeline\n",
    "plot_distance_heatmap(\n",
    "    dist_matrix,\n",
    "    feature_names=attribute_names,\n",
    "    soc_var_names=soc_var_names,\n",
    "    method_name=method_name,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Threshold: keep only pairs whose PRE-POST change exceeds the median distance.\n",
    "# NOTE(review): the median is taken over the full matrix (both triangles and the diagonal),\n",
    "# not only over the pairs examined below - confirm this is the intended threshold.\n",
    "tolerance = np.median(dist_matrix)\n",
    "\n",
    "social_set = set(soc_var_names)\n",
    "individual_set = set(ind_var_names)\n",
    "interaction_columns = ['Variable 1', 'Variable 2', 'SHAP Inter Variation PRE-POST', 'Interaction Type']\n",
    "\n",
    "# Collect interactions above the tolerance from the lower triangle (diagonal excluded)\n",
    "interactions = []\n",
    "for i in range(1, dist_matrix.shape[0]):\n",
    "    for j in range(i):\n",
    "        if dist_matrix[i, j] > tolerance:\n",
    "            var1 = attribute_names[i]\n",
    "            var2 = attribute_names[j]\n",
    "            if var1 in social_set and var2 in social_set:\n",
    "                inter_type = 'Social'\n",
    "            elif var1 in individual_set and var2 in individual_set:\n",
    "                inter_type = 'Individual'\n",
    "            else:\n",
    "                # Any other pairing, including features in neither list\n",
    "                inter_type = 'Mixed'\n",
    "            interactions.append({\n",
    "                'Variable 1': var1,\n",
    "                'Variable 2': var2,\n",
    "                'SHAP Inter Variation PRE-POST': dist_matrix[i, j],\n",
    "                'Interaction Type': inter_type\n",
    "            })\n",
    "\n",
    "# Explicit columns keep the sort below working even if no pair passes the threshold\n",
    "interactions_df = pd.DataFrame(interactions, columns=interaction_columns)\n",
    "\n",
    "# Sort by PRE-POST variation, largest first\n",
    "sorted_interactions_df = interactions_df.sort_values(by='SHAP Inter Variation PRE-POST', ascending=False)\n",
    "\n",
    "# Export to Excel\n",
    "sorted_interactions_df.to_excel(f'./output/inter_variation_{method_name}.xlsx', index=False)\n",
    "\n",
    "print(\"Excel file has been created.\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
479
   "version": "3.12.2"
480 481 482 483 484
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}