shap_plots.ipynb 19 KB
Newer Older
1 2 3 4 5 6
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
7 8 9 10 11 12 13 14 15
    "**SHAP Explainability Plots** \\\n",
    "_Author: Joaquín Torres Bravo_"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Libraries"
16 17 18 19
   ]
  },
  {
   "cell_type": "code",
Joaquin Torres's avatar
Joaquin Torres committed
20
   "execution_count": null,
21 22 23 24 25 26
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import shap\n",
27
    "import matplotlib.pyplot as plt\n",
Joaquin Torres's avatar
Joaquin Torres committed
28
    "import matplotlib.patches as mpatches # Legend\n",
29
    "import seaborn as sns"
30 31 32 33 34 35
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
36
    "### Data"
37 38 39 40
   ]
  },
  {
   "cell_type": "code",
Joaquin Torres's avatar
Joaquin Torres committed
41
   "execution_count": null,
42
   "metadata": {},
Joaquin Torres's avatar
Joaquin Torres committed
43
   "outputs": [],
44 45
   "source": [
    "# Retrieve attribute names in order\n",
Joaquin Torres's avatar
Joaquin Torres committed
46
    "attribute_names = attribute_names = list(np.load('../EDA/output/feature_names/all_features.npy', allow_pickle=True))\n",
47 48
    "\n",
    "# Load test data\n",
Joaquin Torres's avatar
Joaquin Torres committed
49 50 51 52
    "X_test_pre = np.load('../gen_train_data/output/pre/X_test_pre.npy', allow_pickle=True)\n",
    "y_test_pre = np.load('../gen_train_data/output/pre/y_test_pre.npy', allow_pickle=True)\n",
    "X_test_post = np.load('../gen_train_data/output/post/X_test_post.npy', allow_pickle=True)\n",
    "y_test_post = np.load('../gen_train_data/output/post/y_test_post.npy', allow_pickle=True)\n",
53 54 55 56 57 58 59 60 61 62 63 64
    "\n",
    "# Type conversion needed    \n",
    "data_dic = {\n",
    "    \"X_test_pre\": pd.DataFrame(X_test_pre, columns=attribute_names).convert_dtypes(),\n",
    "    \"y_test_pre\": y_test_pre,\n",
    "    \"X_test_post\": pd.DataFrame(X_test_post, columns=attribute_names).convert_dtypes(),\n",
    "    \"y_test_post\": y_test_post,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
Joaquin Torres's avatar
Joaquin Torres committed
65
   "execution_count": null,
66 67 68 69 70 71 72 73 74 75 76 77 78 79
   "metadata": {},
   "outputs": [],
   "source": [
    "method_names = {\n",
    "    0: \"ORIG\",\n",
    "    1: \"ORIG_CW\",\n",
    "    2: \"OVER\",\n",
    "    3: \"UNDER\"\n",
    "}\n",
    "model_choices = {\n",
    "    \"ORIG\": \"XGB\",\n",
    "    \"ORIG_CW\": \"RF\",\n",
    "    \"OVER\": \"XGB\",\n",
    "    \"UNDER\": \"XGB\"\n",
80 81
    "}\n",
    "\n",
82
    "# Load names of social and individual attributes\n",
Joaquin Torres's avatar
Joaquin Torres committed
83 84
    "soc_var_names = np.load('../EDA/output/feature_names/social_factors.npy', allow_pickle=True)\n",
    "ind_var_names = np.load('../EDA/output/feature_names/individual_factors.npy', allow_pickle=True)"
85 86 87 88 89 90
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
91
    "### SHAP Plots"
92 93 94 95
   ]
  },
  {
   "cell_type": "code",
96
   "execution_count": null,
97
   "metadata": {},
98
   "outputs": [],
99
   "source": [
100
    "method_name = 'OVER'\n",
101
    "\n",
102
    "plt.figure(figsize=(35, 75))\n",
103 104 105 106 107
    "for i, group in enumerate(['pre', 'post']):\n",
    "            X_test = data_dic['X_test_' + group]\n",
    "            y_test = data_dic['y_test_' + group]\n",
    "            model_name = model_choices[method_name]\n",
    "            shap_vals = np.load(f'./output/shap_values/{group}_{method_name}.npy')\n",
108 109
    "            ax = plt.subplot(2,1,i+1) # 2 rows (pre - post) 1 column\n",
    "            # show = False to modify plot before showing\n",
110
    "            shap.summary_plot(shap_vals, X_test, max_display=len(attribute_names), show=False)\n",
111 112
    "            plt.title(group.upper(), fontsize = 12, fontweight='bold')\n",
    "            plt.xlabel('SHAP Value')\n",
113
    "            plt.xlim(-3,5)\n",
114
    "            used_colors = {'purple': 'Social factor', 'green': 'Individual factor'}\n",
115
    "            # Modify color of attributes\n",
116 117 118 119 120 121 122 123 124 125
    "            for label in ax.get_yticklabels():\n",
    "                label_text = label.get_text()  # Get the text of the label\n",
    "                label.set_fontsize(8)\n",
    "                if label_text in soc_var_names:\n",
    "                        label.set_color('purple')\n",
    "                else:\n",
    "                        label.set_color('green')\n",
    "                # Create custom legend for each subplot\n",
    "                handles = [mpatches.Patch(color=color, label=label) for color, label in used_colors.items()]\n",
    "                ax.legend(handles=handles, loc='lower right', fontsize=8)\n",
126
    "            \n",
127 128
    "plt.suptitle(f'SHAP Summary Plots PRE vs POST - Pipeline: Oversampling - Model: {model_name}\\n\\n')\n",
    "plt.subplots_adjust(wspace=1)\n",
129
    "plt.tight_layout()\n",
Joaquin Torres's avatar
Joaquin Torres committed
130 131
    "plt.savefig(f'./output/plots/shap_summary/{method_name}_{model_name}.svg', format='svg', dpi=1250)\n",
    "plt.show()"
132
   ]
133 134 135
  },
  {
   "cell_type": "code",
Joaquin Torres's avatar
Joaquin Torres committed
136
   "execution_count": null,
137
   "metadata": {},
Joaquin Torres's avatar
Joaquin Torres committed
138
   "outputs": [],
139
   "source": [
Joaquin Torres's avatar
Joaquin Torres committed
140 141 142 143 144 145 146 147 148 149 150 151 152 153 154
    "method_name = 'ORIG_CW'\n",
    "plt.figure(figsize=(35, 75))\n",
    "for i, group in enumerate(['pre', 'post']):\n",
    "            X_test = data_dic['X_test_' + group]\n",
    "            y_test = data_dic['y_test_' + group]\n",
    "            model_name = model_choices[method_name]\n",
    "            shap_vals = np.load(f'./output/shap_values/{group}_{method_name}.npy')\n",
    "            shap_vals = shap_vals[:,:,1] # Select shap values for positive class\n",
    "            ax = plt.subplot(2,1,i+1)\n",
    "            shap.summary_plot(shap_vals, X_test, max_display=len(attribute_names), show=False)\n",
    "            plt.title(group.upper(), fontsize = 12, fontweight='bold')\n",
    "            plt.xlabel('SHAP Value')\n",
    "            plt.xlim(-0.5,0.5)\n",
    "            used_colors = {'purple': 'Social factor', 'green': 'Individual factor'}\n",
    "            for label in ax.get_yticklabels():\n",
155
    "                label_text = label.get_text()\n",
Joaquin Torres's avatar
Joaquin Torres committed
156 157 158 159 160 161 162
    "                label.set_fontsize(8)\n",
    "                if label_text in soc_var_names:\n",
    "                        label.set_color('purple')\n",
    "                else:\n",
    "                        label.set_color('green')\n",
    "                handles = [mpatches.Patch(color=color, label=label) for color, label in used_colors.items()]\n",
    "                ax.legend(handles=handles, loc='lower right', fontsize=8)\n",
163
    "\n",
Joaquin Torres's avatar
Joaquin Torres committed
164 165 166
    "plt.suptitle(f'SHAP Summary Plots PRE vs POST - Pipeline: Original with Class Weight - Model: {model_name}\\n\\n')\n",
    "plt.subplots_adjust(wspace=1)\n",
    "plt.tight_layout()\n",
Joaquin Torres's avatar
Joaquin Torres committed
167 168
    "plt.savefig(f'./output/plots/shap_summary/{method_name}_{model_name}.svg', format='svg', dpi=1250)\n",
    "plt.show()"
169
   ]
170 171 172 173 174
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
175 176 177 178 179 180 181
    "### SHAP Interaction Plots"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
Joaquin Torres's avatar
Joaquin Torres committed
182 183
    "**IMPORTANT NOTE** \\\n",
    "For the code to work as intended, the SHAP source code had to be modified: .venv/lib/python3.9/site-packages/shap/plots/_beeswarm.py \n",
184 185 186 187 188 189
    "\n",
    "sort_inds = np.argsort(-np.abs(shap_values.sum(1)).sum(0)) \\\n",
    "**replaced by** \\\n",
    "sort_inds = np.arange(39)\n",
    "\n",
    "The idea is to display the features in the original order instead of sorting them according to their absolute SHAP impact, as the library originally does\n"
190 191 192 193
   ]
  },
  {
   "cell_type": "code",
Joaquin Torres's avatar
Joaquin Torres committed
194
   "execution_count": null,
195
   "metadata": {},
Joaquin Torres's avatar
Joaquin Torres committed
196
   "outputs": [],
197
   "source": [
198 199 200 201 202
    "for group in ['pre', 'post']:\n",
    "    for method_name in ['ORIG_CW', 'OVER']:\n",
    "        X_test = data_dic['X_test_' + group]\n",
    "        y_test = data_dic['y_test_' + group]\n",
    "        model_name = model_choices[method_name]\n",
203
    "\n",
204 205 206
    "        shap_inter_vals = np.load(f'./output/shap_inter_values/{group}_{method_name}.npy')\n",
    "        if method_name == 'ORIG_CW':\n",
    "            shap_inter_vals = shap_inter_vals[:,:,:,1]  # Take info about positive class\n",
Joaquin Torres's avatar
Joaquin Torres committed
207
    "\n",
208 209
    "        num_instances = shap_inter_vals.shape[0]  # Dynamically get the number of instances\n",
    "        num_features = shap_inter_vals.shape[1]  # Assuming the number of features is the second dimension size\n",
210
    "\n",
211 212 213 214 215 216 217 218
    "        # Loop over each instance and set the diagonal and lower triangle of each 39x39 matrix to NaN\n",
    "        for i in range(num_instances):\n",
    "            # Mask the diagonal\n",
    "            np.fill_diagonal(shap_inter_vals[i], np.nan)\n",
    "            # Get indices for the lower triangle, excluding the diagonal\n",
    "            lower_triangle_indices = np.tril_indices(num_features, -1)  # -1 excludes the diagonal\n",
    "            # Set the lower triangle to NaN\n",
    "            shap_inter_vals[i][lower_triangle_indices] = np.nan\n",
219
    "\n",
220 221 222 223 224
    "        plt.figure(figsize=(10,10))\n",
    "        shap.summary_plot(shap_inter_vals, X_test, show=False, sort=False, max_display=len(attribute_names))\n",
    "        fig = plt.gcf()\n",
    "        attr_names = []\n",
    "        used_colors = {'purple': 'Social factor', 'green': 'Individual factor'}\n",
225
    "\n",
226 227 228 229 230 231 232 233 234 235 236
    "        # Iterate over all axes in the figure\n",
    "        for ax in fig.get_axes():\n",
    "            # Customize the y-axis tick labels\n",
    "            for label in ax.get_yticklabels():\n",
    "                label_text = label.get_text()  # Get the text of the label\n",
    "                attr_names.append(label_text)\n",
    "                label.set_fontsize(12)\n",
    "                if label_text in soc_var_names:\n",
    "                    label.set_color('purple')\n",
    "                else:\n",
    "                    label.set_color('green')\n",
237
    "\n",
238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257
    "        # Assuming the top labels are treated as titles, let's try to modify them\n",
    "        total_axes = len(fig.axes)\n",
    "        for i, ax in enumerate(fig.axes):\n",
    "            reverse_index = total_axes - 1 - i\n",
    "            title = attr_names[reverse_index]\n",
    "            ax.set_title(title, color='purple' if title in soc_var_names else 'green', fontsize=12, rotation=90)\n",
    "            if method_name == 'ORIG_CW':\n",
    "                ax.set_xlim(-0.15, 0.15)  # Use same scale for pre and post\n",
    "            elif method_name == 'OVER':\n",
    "                ax.set_xlim(-2, 2)\n",
    "\n",
    "        # Create a single general legend for the whole figure\n",
    "        handles = [mpatches.Patch(color=color, label=label) for color, label in used_colors.items()]\n",
    "        fig.legend(handles=handles, loc='lower right', fontsize=12)\n",
    "\n",
    "        # plt.suptitle(f'Simplified Example SHAP Summary Interaction Plot\\n', fontsize=15, fontweight='bold', x=0.5, y=0.95, ha='center')\n",
    "        plt.suptitle(f'SHAP Summary Interaction Plot - {method_name} - {str.upper(group)}\\n', fontsize=20, fontweight='bold') #, x=0.5, y=0.95, ha='center'\n",
    "        plt.tight_layout()\n",
    "        plt.savefig(f'./output/plots/shap_inter_summary/{str.upper(group)}_{method_name}_{model_name}.svg', format='svg', dpi=700)\n",
    "        plt.show()"
258
   ]
259 260 261 262 263
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
264
    "#### SHAP Interaction Analysis (To Be Confirmed)"
265 266 267 268
   ]
  },
  {
   "cell_type": "code",
269
   "execution_count": null,
270 271 272 273 274 275 276 277
   "metadata": {},
   "outputs": [],
   "source": [
    "method_name = 'OVER'"
   ]
  },
  {
   "cell_type": "code",
278
   "execution_count": null,
279 280 281 282 283 284 285 286 287 288
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalization function\n",
    "def max_absolute_scale(matrix):\n",
    "    return matrix / np.max(np.abs(matrix))"
   ]
  },
  {
   "cell_type": "code",
289
   "execution_count": null,
290 291 292 293 294 295 296
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load array of shap inter matrices for pre and post for the chosen method\n",
    "shap_inter_vals_pre = np.load(f'./output/shap_inter_values/pre_{method_name}.npy')\n",
    "shap_inter_vals_post = np.load(f'./output/shap_inter_values/post_{method_name}.npy')\n",
    "if method_name == 'ORIG_CW':\n",
297
    "    shap_inter_vals_pre = shap_inter_vals_pre[:,:,:,1]\n",
298 299 300 301 302 303 304 305 306 307 308
    "    shap_inter_vals_post = shap_inter_vals_post[:,:,:,1]\n",
    "\n",
    "# Normalize each matrix in each of the two arrays\n",
    "norm_shap_inter_vals_pre = np.array([max_absolute_scale(matrix) for matrix in shap_inter_vals_pre])\n",
    "norm_shap_inter_vals_post = np.array([max_absolute_scale(matrix) for matrix in shap_inter_vals_post])\n",
    "\n",
    "# Aggregate matrices in each group by calculating the mean\n",
    "agg_shap_inter_vals_pre = np.mean(norm_shap_inter_vals_pre, axis=0)\n",
    "agg_shap_inter_vals_post = np.mean(norm_shap_inter_vals_post, axis=0)\n",
    "\n",
    "# Compute the difference between the aggregated matrices and take the absolute value\n",
309
    "dist_matrix = np.abs(agg_shap_inter_vals_post - agg_shap_inter_vals_pre)"
310 311 312 313
   ]
  },
  {
   "cell_type": "code",
314
   "execution_count": null,
315 316 317 318 319 320 321 322 323
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_inter_heatmap(matrix, feature_names, soc_var_names, method_name, group):\n",
    "    # Create a mask for the upper triangle\n",
    "    mask = np.triu(np.ones_like(matrix, dtype=bool))\n",
    "\n",
    "    # Create the heatmap using Seaborn\n",
    "    plt.figure(figsize=(12, 12))\n",
Joaquin Torres's avatar
Joaquin Torres committed
324
    "    ax = sns.heatmap(matrix, mask=mask, cmap='coolwarm', center=0, annot=False, cbar=True,\n",
325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353
    "                xticklabels=feature_names, yticklabels=feature_names)\n",
    "    \n",
    "    for tick_label in ax.get_yticklabels():\n",
    "        if tick_label.get_text() in soc_var_names:\n",
    "            tick_label.set_color('purple')  # Specific social variables\n",
    "        else:\n",
    "            tick_label.set_color('green')  # Other variables\n",
    "    \n",
    "    for tick_label in ax.get_xticklabels():\n",
    "        if tick_label.get_text() in soc_var_names:\n",
    "            tick_label.set_color('purple')  # Specific social variables\n",
    "        else:\n",
    "            tick_label.set_color('green')  # Other variables\n",
    "        \n",
    "    plt.title(f'Mean Interaction Matrix - Pipeline: {method_name} - Group: {group}\\n', fontdict={'fontstyle': 'normal', 'weight': 'bold'})\n",
    "    # Create a custom legend\n",
    "    purple_patch = mpatches.Patch(color='purple', label='Social factor')\n",
    "    green_patch = mpatches.Patch(color='green', label='Individual factor')\n",
    "    plt.legend(handles=[purple_patch, green_patch], loc='upper right')\n",
    "    # Add a title to the color bar\n",
    "    cbar = ax.collections[0].colorbar\n",
    "    cbar.set_label('Normalized SHAP Interaction Value', labelpad=15, rotation=270, verticalalignment='bottom')\n",
    "    \n",
    "    plt.savefig(f'./output/plots/heatmaps_interactions/{method_name}_{group}.svg', format='svg', dpi=600)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
354
   "execution_count": null,
355
   "metadata": {},
356
   "outputs": [],
357
   "source": [
Joaquin Torres's avatar
Joaquin Torres committed
358
    "plot_inter_heatmap(agg_shap_inter_vals_post, attribute_names, soc_var_names, method_name, 'POST')"
359 360 361 362
   ]
  },
  {
   "cell_type": "code",
363
   "execution_count": null,
364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_distance_heatmap(matrix, feature_names, soc_var_names, method_name):\n",
    "    # Create a mask for the upper triangle\n",
    "    mask = np.triu(np.ones_like(matrix, dtype=bool))\n",
    "\n",
    "    # Create the heatmap using Seaborn\n",
    "    plt.figure(figsize=(12, 12))\n",
    "    ax = sns.heatmap(matrix, mask=mask, cmap=sns.color_palette(\"light:r\", as_cmap=True), annot=False, cbar=True,\n",
    "                xticklabels=feature_names, yticklabels=feature_names)\n",
    "\n",
    "    for tick_label in ax.get_yticklabels():\n",
    "        if tick_label.get_text() in soc_var_names:\n",
    "            tick_label.set_color('purple')  # Specific social variables\n",
    "        else:\n",
    "            tick_label.set_color('green')  # Other variables\n",
    "    \n",
    "    for tick_label in ax.get_xticklabels():\n",
    "        if tick_label.get_text() in soc_var_names:\n",
    "            tick_label.set_color('purple')  # Specific social variables\n",
    "        else:\n",
    "            tick_label.set_color('green')  # Other variables\n",
    "        \n",
    "    plt.title(f'Distance Interaction Matrix between PRE and POST - Pipeline: {method_name}\\n', fontdict={'fontstyle': 'normal', 'weight': 'bold'})\n",
    "    # Create a custom legend\n",
    "    purple_patch = mpatches.Patch(color='purple', label='Social factor')\n",
    "    green_patch = mpatches.Patch(color='green', label='Individual factor')\n",
    "    plt.legend(handles=[purple_patch, green_patch], loc='upper right')\n",
    "    # Add a title to the color bar\n",
    "    cbar = ax.collections[0].colorbar\n",
    "    cbar.set_label('Abs (PRE - POST)', labelpad=15, rotation=270, verticalalignment='bottom')\n",
    "\n",
    "    plt.savefig(f'./output/plots/heatmaps_interactions/DIST_{method_name}.svg', format='svg', dpi=600)\n",
    "    \n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
404
   "execution_count": null,
405
   "metadata": {},
406
   "outputs": [],
407
   "source": [
408
    "plot_distance_heatmap(dist_matrix, attribute_names, soc_var_names, method_name)"
409
   ]
Joaquin Torres's avatar
Joaquin Torres committed
410 411 412
  },
  {
   "cell_type": "code",
413
   "execution_count": null,
Joaquin Torres's avatar
Joaquin Torres committed
414
   "metadata": {},
415
   "outputs": [],
416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451
   "source": [
    "# Define tolerance\n",
    "tolerance = np.median(dist_matrix)\n",
    "# Create a DataFrame to hold the interactions\n",
    "interactions = []\n",
    "\n",
    "# Iterate over the matrix to extract interactions above the tolerance\n",
    "for i in range(1, dist_matrix.shape[0]):\n",
    "    for j in range(i):  # Lower triangle exclduing diagonal\n",
    "        if dist_matrix[i, j] > tolerance:\n",
    "            var1 = attribute_names[i]\n",
    "            var2 = attribute_names[j]\n",
    "            if var1 in soc_var_names and var2 in soc_var_names:\n",
    "                inter_type = 'Social'\n",
    "            elif var1 in ind_var_names and var2 in ind_var_names:\n",
    "                inter_type = 'Individual'\n",
    "            else:\n",
    "                inter_type = 'Mixed'\n",
    "            interactions.append({\n",
    "                'Variable 1': var1, \n",
    "                'Variable 2': var2,\n",
    "                'SHAP Inter Variation PRE-POST': dist_matrix[i, j],\n",
    "                'Interaction Type': inter_type\n",
    "            })\n",
    "\n",
    "# Convert the list of dictionaries to a DataFrame\n",
    "interactions_df = pd.DataFrame(interactions)\n",
    "\n",
    "# Sort the DataFrame by 'Interaction Strength' in descending order\n",
    "sorted_interactions_df = interactions_df.sort_values(by='SHAP Inter Variation PRE-POST', ascending=False)\n",
    "\n",
    "# Export to Excel\n",
    "sorted_interactions_df.to_excel(f'./output/inter_variation_{method_name}.xlsx', index=False)\n",
    "\n",
    "print(\"Excel file has been created.\")"
   ]
452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
470
   "version": "3.9.5"
471 472 473 474 475
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}