shap_plots.ipynb 17.7 KB
Newer Older
1 2 3 4 5 6
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
7 8 9 10 11 12 13 14 15
    "**SHAP Explainability Plots** \\\n",
    "_Author: Joaquín Torres Bravo_"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Libraries"
16 17 18 19
   ]
  },
  {
   "cell_type": "code",
20
   "execution_count": null,
21 22 23 24 25 26
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import shap\n",
27
    "import matplotlib.pyplot as plt\n",
Joaquin Torres's avatar
Joaquin Torres committed
28
    "import matplotlib.patches as mpatches # Legend\n",
29
    "import matplotlib.colors as mcolors # Colormap\n",
Joaquin Torres's avatar
Joaquin Torres committed
30 31
    "import seaborn as sns\n",
    "from scipy.stats import zscore # Normalization"
32 33 34 35 36 37
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
38
    "### Data"
39 40 41 42
   ]
  },
  {
   "cell_type": "code",
Joaquin Torres's avatar
Joaquin Torres committed
43
   "execution_count": null,
44
   "metadata": {},
Joaquin Torres's avatar
Joaquin Torres committed
45
   "outputs": [],
46 47
   "source": [
    "# Feature names, in the column order of the saved test matrices\n",
    "attribute_names = list(np.load('../01-EDA/results/feature_names/all_features.npy', allow_pickle=True))\n",
    "\n",
    "# Load PRE and POST test sets\n",
    "X_test_pre = np.load('../02-training_data_generation/results/pre/X_test_pre.npy', allow_pickle=True)\n",
    "y_test_pre = np.load('../02-training_data_generation/results/pre/y_test_pre.npy', allow_pickle=True)\n",
    "X_test_post = np.load('../02-training_data_generation/results/post/X_test_post.npy', allow_pickle=True)\n",
    "y_test_post = np.load('../02-training_data_generation/results/post/y_test_post.npy', allow_pickle=True)\n",
    "\n",
    "# Wrap the feature matrices in DataFrames so the SHAP plots get feature names as labels.\n",
    "# convert_dtypes() infers per-column dtypes; NOTE(review): presumably needed because the arrays\n",
    "# were saved with allow_pickle and come back as dtype=object - confirm.\n",
    "data_dic = {\n",
    "    \"X_test_pre\": pd.DataFrame(X_test_pre, columns=attribute_names).convert_dtypes(),\n",
    "    \"y_test_pre\": y_test_pre,\n",
    "    \"X_test_post\": pd.DataFrame(X_test_post, columns=attribute_names).convert_dtypes(),\n",
    "    \"y_test_post\": y_test_post,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
Joaquin Torres's avatar
Joaquin Torres committed
67
   "execution_count": null,
68 69 70 71 72 73 74 75 76 77 78 79 80 81
   "metadata": {},
   "outputs": [],
   "source": [
    "# Index -> pipeline name\n",
    "method_names = dict(enumerate([\"ORIG\", \"ORIG_CW\", \"OVER\", \"UNDER\"]))\n",
    "\n",
    "# Model chosen for each pipeline (presumably XGB = XGBoost, RF = Random Forest)\n",
    "model_choices = dict(ORIG=\"XGB\", ORIG_CW=\"RF\", OVER=\"XGB\", UNDER=\"XGB\")\n",
    "\n",
    "# Names of the social and individual attributes (used below to colour feature labels)\n",
    "soc_var_names = np.load('../01-EDA/results/feature_names/social_factors.npy', allow_pickle=True)\n",
    "ind_var_names = np.load('../01-EDA/results/feature_names/individual_factors.npy', allow_pickle=True)"
87 88 89 90 91 92
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
93
    "### SHAP Plots"
94 95 96 97
   ]
  },
  {
   "cell_type": "code",
98
   "execution_count": null,
99
   "metadata": {},
100
   "outputs": [],
101
   "source": [
102
    "method_name = 'OVER'\n",
    "model_name = model_choices[method_name]\n",
    "\n",
    "# Legend explaining the y-tick label colours set below\n",
    "used_colors = {'purple': 'Social factor', 'green': 'Individual factor'}\n",
    "legend_handles = [mpatches.Patch(color=color, label=label) for color, label in used_colors.items()]\n",
    "\n",
    "plt.figure(figsize=(35, 75))\n",
    "for i, group in enumerate(['pre', 'post']):\n",
    "    X_test = data_dic['X_test_' + group]\n",
    "    shap_vals = np.load(f'./results/shap_values/{group}_{method_name}.npy')\n",
    "    ax = plt.subplot(2, 1, i + 1)  # 2 rows (pre - post), 1 column\n",
    "    # show=False so the plot can still be modified before it is displayed\n",
    "    shap.summary_plot(shap_vals, X_test, max_display=len(attribute_names), show=False)\n",
    "    plt.title(group.upper(), fontsize=12, fontweight='bold')\n",
    "    plt.xlabel('SHAP Value')\n",
    "    plt.xlim(-3, 5)\n",
    "    # Colour each feature name by factor type\n",
    "    for label in ax.get_yticklabels():\n",
    "        label.set_fontsize(8)\n",
    "        label.set_color('purple' if label.get_text() in soc_var_names else 'green')\n",
    "    # One legend per subplot, added once after the labels are coloured\n",
    "    ax.legend(handles=legend_handles, loc='lower right', fontsize=8)\n",
    "\n",
    "plt.suptitle(f'SHAP Summary Plots PRE vs POST - Pipeline: Oversampling - Model: {model_name}\\n\\n')\n",
    "plt.tight_layout()\n",
    "plt.savefig(f'./results/plots/shap_summary/{method_name}_{model_name}.svg', format='svg', dpi=1250)\n",
    "plt.show()"
134
   ]
135 136 137
  },
  {
   "cell_type": "code",
Joaquin Torres's avatar
Joaquin Torres committed
138
   "execution_count": null,
139
   "metadata": {},
Joaquin Torres's avatar
Joaquin Torres committed
140
   "outputs": [],
141
   "source": [
Joaquin Torres's avatar
Joaquin Torres committed
142 143 144 145 146 147
    "method_name = 'ORIG_CW'\n",
    "model_name = model_choices[method_name]\n",
    "\n",
    "# Legend explaining the y-tick label colours set below\n",
    "used_colors = {'purple': 'Social factor', 'green': 'Individual factor'}\n",
    "legend_handles = [mpatches.Patch(color=color, label=label) for color, label in used_colors.items()]\n",
    "\n",
    "plt.figure(figsize=(35, 75))\n",
    "for i, group in enumerate(['pre', 'post']):\n",
    "    X_test = data_dic['X_test_' + group]\n",
    "    shap_vals = np.load(f'./results/shap_values/{group}_{method_name}.npy')\n",
    "    shap_vals = shap_vals[:, :, 1]  # Select SHAP values for the positive class (index 1 of the last axis)\n",
    "    ax = plt.subplot(2, 1, i + 1)  # 2 rows (pre - post), 1 column\n",
    "    # show=False so the plot can still be modified before it is displayed\n",
    "    shap.summary_plot(shap_vals, X_test, max_display=len(attribute_names), show=False)\n",
    "    plt.title(group.upper(), fontsize=12, fontweight='bold')\n",
    "    plt.xlabel('SHAP Value')\n",
    "    plt.xlim(-0.5, 0.5)\n",
    "    # Colour each feature name by factor type\n",
    "    for label in ax.get_yticklabels():\n",
    "        label.set_fontsize(8)\n",
    "        label.set_color('purple' if label.get_text() in soc_var_names else 'green')\n",
    "    # One legend per subplot, added once after the labels are coloured\n",
    "    ax.legend(handles=legend_handles, loc='lower right', fontsize=8)\n",
    "\n",
    "plt.suptitle(f'SHAP Summary Plots PRE vs POST - Pipeline: Original with Class Weight - Model: {model_name}\\n\\n')\n",
    "plt.tight_layout()\n",
    "plt.savefig(f'./results/plots/shap_summary/{method_name}_{model_name}.svg', format='svg', dpi=1250)\n",
    "plt.show()"
171
   ]
172 173 174 175 176
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
177 178 179 180 181 182 183
    "### SHAP Interaction Plots"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
Joaquin Torres's avatar
Joaquin Torres committed
184 185
    "**IMPORTANT NOTE** \\\n",
    "For the interaction plots below to work as intended, the installed SHAP source had to be modified in `.venv/lib/python3.9/site-packages/shap/plots/_beeswarm.py` (path from the author's environment; adjust the Python version to match your venv):\n",
    "\n",
    "`sort_inds = np.argsort(-np.abs(shap_values.sum(1)).sum(0))` \\\n",
    "**replaced by** \\\n",
    "`sort_inds = np.arange(39)` (39 = number of features)\n",
    "\n",
    "The goal is to display the features in their original order instead of sorting them by absolute SHAP impact, which is what the library does by default.\n"
192 193 194 195
   ]
  },
  {
   "cell_type": "code",
Joaquin Torres's avatar
Joaquin Torres committed
196
   "execution_count": null,
197
   "metadata": {},
Joaquin Torres's avatar
Joaquin Torres committed
198
   "outputs": [],
199
   "source": [
200 201
    "# Per-pipeline x-axis range, shared by PRE and POST so both plots use the same scale\n",
    "inter_xlims = {'ORIG_CW': (-0.15, 0.15), 'OVER': (-2, 2)}\n",
    "used_colors = {'purple': 'Social factor', 'green': 'Individual factor'}\n",
    "legend_handles = [mpatches.Patch(color=color, label=label) for color, label in used_colors.items()]\n",
    "\n",
    "for method_name in ['ORIG_CW', 'OVER']:\n",
    "    model_name = model_choices[method_name]\n",
    "    xlim = inter_xlims.get(method_name)\n",
    "    for group in ['pre', 'post']:\n",
    "        X_test = data_dic['X_test_' + group]\n",
    "\n",
    "        shap_inter_vals = np.load(f'./results/shap_inter_values/{group}_{method_name}.npy')\n",
    "        if method_name == 'ORIG_CW':\n",
    "            shap_inter_vals = shap_inter_vals[:, :, :, 1]  # Take info about positive class\n",
    "\n",
    "        # Each instance has a square (features x features) interaction matrix. Set the diagonal\n",
    "        # and lower triangle of every matrix to NaN in one vectorised assignment, so only the\n",
    "        # upper triangle (each feature pair once) is plotted.\n",
    "        num_features = shap_inter_vals.shape[1]\n",
    "        rows, cols = np.tril_indices(num_features)  # k=0: lower triangle including the diagonal\n",
    "        shap_inter_vals[:, rows, cols] = np.nan\n",
    "\n",
    "        # NOTE(review): shap may open its own figure for interaction plots, leaving this one empty - confirm\n",
    "        plt.figure(figsize=(10, 10))\n",
    "        shap.summary_plot(shap_inter_vals, X_test, show=False, sort=False, max_display=len(attribute_names))\n",
    "        fig = plt.gcf()\n",
    "\n",
    "        # Colour every y-tick label by factor type and collect the label texts in axes order\n",
    "        attr_names = []\n",
    "        for ax in fig.get_axes():\n",
    "            for label in ax.get_yticklabels():\n",
    "                label_text = label.get_text()\n",
    "                attr_names.append(label_text)\n",
    "                label.set_fontsize(12)\n",
    "                label.set_color('purple' if label_text in soc_var_names else 'green')\n",
    "\n",
    "        # Replace each subplot title with the full feature name, coloured by factor type.\n",
    "        # Axes i is paired with attr_names[total_axes - 1 - i] (reverse order); NOTE(review): this\n",
    "        # relies on shap's layout, where the first axes' tick labels list the features - confirm.\n",
    "        total_axes = len(fig.axes)\n",
    "        for i, ax in enumerate(fig.axes):\n",
    "            title = attr_names[total_axes - 1 - i]\n",
    "            ax.set_title(title, color='purple' if title in soc_var_names else 'green', fontsize=12, rotation=90)\n",
    "            if xlim is not None:\n",
    "                ax.set_xlim(*xlim)  # Use same scale for pre and post\n",
    "\n",
    "        # Single general legend for the whole figure\n",
    "        fig.legend(handles=legend_handles, loc='lower right', fontsize=12)\n",
    "        fig.suptitle(f'SHAP Summary Interaction Plot - {method_name} - {group.upper()}\\n', fontsize=20, fontweight='bold')\n",
    "        fig.tight_layout()\n",
    "        fig.savefig(f'./results/plots/shap_inter_summary/{group.upper()}_{method_name}_{model_name}.svg', format='svg', dpi=700)"
260
   ]
261 262 263 264 265
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
266 267 268 269 270 271 272 273
    "### SHAP Interaction Analysis"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Distance Heatmaps"
274 275 276 277
   ]
  },
  {
   "cell_type": "code",
278
   "execution_count": null,
279 280 281
   "metadata": {},
   "outputs": [],
   "source": [
Joaquin Torres's avatar
Ready  
Joaquin Torres committed
282
    "# Pipeline whose PRE and POST interaction matrices are compared in this section\n",
    "# (the interaction-plot loop above covers 'ORIG_CW' and 'OVER')\n",
    "method_name = \"OVER\""
283 284 285 286
   ]
  },
  {
   "cell_type": "code",
287
   "execution_count": null,
288 289 290
   "metadata": {},
   "outputs": [],
   "source": [
Joaquin Torres's avatar
Joaquin Torres committed
291 292 293
    "def zscore_normalize(matrix):\n",
    "    \"\"\"Return `matrix` z-scored with a single mean and standard deviation taken over all its entries.\n",
    "\n",
    "    axis=None makes scipy treat the whole matrix as one flat sample instead of normalising per row or column.\n",
    "    \"\"\"\n",
    "    normalized = zscore(matrix, axis=None)\n",
    "    return normalized"
294 295 296 297
   ]
  },
  {
   "cell_type": "code",
298
   "execution_count": null,
299 300 301 302
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the stacks of SHAP interaction matrices (one matrix per instance) for PRE and POST\n",
    "shap_inter_vals_pre, shap_inter_vals_post = (\n",
    "    np.load(f'./results/shap_inter_values/{group}_{method_name}.npy') for group in ('pre', 'post')\n",
    ")\n",
    "if method_name == 'ORIG_CW':\n",
    "    # Keep the positive-class slice (index 1 of the last axis)\n",
    "    shap_inter_vals_pre = shap_inter_vals_pre[:, :, :, 1]\n",
    "    shap_inter_vals_post = shap_inter_vals_post[:, :, :, 1]\n",
    "\n",
    "# z-score every matrix independently (one mean/std per matrix)\n",
    "norm_shap_inter_vals_pre, norm_shap_inter_vals_post = (\n",
    "    np.array([zscore_normalize(m) for m in stack]) for stack in (shap_inter_vals_pre, shap_inter_vals_post)\n",
    ")\n",
    "\n",
    "# Average the normalised matrices over all instances of each group\n",
    "agg_shap_inter_vals_pre = norm_shap_inter_vals_pre.mean(axis=0)\n",
    "agg_shap_inter_vals_post = norm_shap_inter_vals_post.mean(axis=0)\n",
    "\n",
    "# Signed difference POST - PRE (called the 'distance' matrix in the plots below)\n",
    "dist_matrix = agg_shap_inter_vals_post - agg_shap_inter_vals_pre"
319 320 321 322
   ]
  },
  {
   "cell_type": "code",
323
   "execution_count": null,
324 325 326
   "metadata": {},
   "outputs": [],
   "source": [
327 328 329 330 331
    "# Diverging colormap for the distance heatmap (which is drawn with center=0): low end red,\n",
    "# midpoint white, high end green. Note (0, 1, 0) is pure green despite the 'dark_green' name.\n",
    "cmap_name = 'custom_red_white_dark_green'\n",
    "n_bins = 100  # number of discrete colour levels\n",
    "colors = [\n",
    "    (1, 0, 0),  # red\n",
    "    (1, 1, 1),  # white\n",
    "    (0, 1, 0),  # green\n",
    "]\n",
    "cmap = mcolors.LinearSegmentedColormap.from_list(cmap_name, colors, N=n_bins)"
332 333 334 335
   ]
  },
  {
   "cell_type": "code",
336
   "execution_count": null,
337 338 339 340 341 342 343 344 345
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_distance_heatmap(matrix, feature_names, soc_var_names, method_name, colormap=None):\n",
    "    \"\"\"Plot, save and show the strictly lower triangle of a PRE/POST interaction difference matrix.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    matrix : square 2-D array of (POST - PRE) interaction values\n",
    "    feature_names : row/column labels, in matrix order\n",
    "    soc_var_names : names of social factors; their tick labels are drawn purple, all others green\n",
    "    method_name : pipeline name used in the title and in the output file name\n",
    "    colormap : optional colormap; defaults to the notebook-level `cmap`\n",
    "\n",
    "    The figure is saved to ./results/plots/heatmaps_interactions/DIST_<method_name>.svg.\n",
    "    \"\"\"\n",
    "    if colormap is None:\n",
    "        colormap = cmap  # notebook-level colormap defined in an earlier cell\n",
    "\n",
    "    # Hide the upper triangle AND the diagonal; only the strictly lower triangle is drawn\n",
    "    mask = np.triu(np.ones_like(matrix, dtype=bool))\n",
    "\n",
    "    plt.figure(figsize=(12, 12))\n",
    "    ax = sns.heatmap(matrix, mask=mask, cmap=colormap, center=0, annot=False, cbar=True,\n",
    "                     xticklabels=feature_names, yticklabels=feature_names)\n",
    "\n",
    "    # Colour the tick labels of both axes by factor type\n",
    "    for tick_label in [*ax.get_xticklabels(), *ax.get_yticklabels()]:\n",
    "        tick_label.set_color('purple' if tick_label.get_text() in soc_var_names else 'green')\n",
    "\n",
    "    ax.set_title(f'Distance Interaction Matrix between PRE and POST - Pipeline: {method_name}\\n', fontdict={'fontstyle': 'normal', 'weight': 'bold'})\n",
    "\n",
    "    # Custom legend for the tick-label colours\n",
    "    legend_handles = [mpatches.Patch(color='purple', label='Social factor'),\n",
    "                      mpatches.Patch(color='green', label='Individual factor')]\n",
    "    ax.legend(handles=legend_handles, loc='upper right')\n",
    "\n",
    "    # Title for the colour bar seaborn attached to the heatmap\n",
    "    cbar = ax.collections[0].colorbar\n",
    "    cbar.set_label('Interaction POST - Interaction PRE', labelpad=15, rotation=270, verticalalignment='bottom')\n",
    "\n",
    "    plt.savefig(f'./results/plots/heatmaps_interactions/DIST_{method_name}.svg', format='svg', dpi=600)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
377
   "execution_count": null,
378
   "metadata": {},
379
   "outputs": [],
380
   "source": [
381
    "plot_distance_heatmap(matrix=dist_matrix, feature_names=attribute_names,\n",
    "                      soc_var_names=soc_var_names, method_name=method_name)"
382
   ]
Joaquin Torres's avatar
Joaquin Torres committed
383
  },
384 385 386 387 388 389 390
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Excel Differences Sorted"
   ]
  },
Joaquin Torres's avatar
Joaquin Torres committed
391 392
  {
   "cell_type": "code",
Joaquin Torres's avatar
Ready  
Joaquin Torres committed
393
   "execution_count": null,
Joaquin Torres's avatar
Joaquin Torres committed
394
   "metadata": {},
Joaquin Torres's avatar
Ready  
Joaquin Torres committed
395
   "outputs": [],
396 397
   "source": [
    "# Tolerance: median absolute PRE->POST variation over the whole matrix.\n",
    "# NOTE(review): the median covers all entries (diagonal and both triangles), while only the strictly\n",
    "# lower triangle is exported below - confirm this is the intended reference population.\n",
    "tolerance = np.median(np.abs(dist_matrix))\n",
    "\n",
    "# Collect lower-triangle pairs (diagonal excluded) whose variation exceeds the tolerance.\n",
    "# np.tril_indices(n, k=-1) yields (i, j) with j < i in row-major order, like `for i: for j in range(i)`.\n",
    "interactions = []\n",
    "rows, cols = np.tril_indices(dist_matrix.shape[0], k=-1)\n",
    "for i, j in zip(rows, cols):\n",
    "    variation = dist_matrix[i, j]\n",
    "    if abs(variation) > tolerance:  # NaN entries compare False and are skipped\n",
    "        var1 = attribute_names[i]\n",
    "        var2 = attribute_names[j]\n",
    "        if var1 in soc_var_names and var2 in soc_var_names:\n",
    "            inter_type = 'Social'\n",
    "        elif var1 in ind_var_names and var2 in ind_var_names:\n",
    "            inter_type = 'Individual'\n",
    "        else:\n",
    "            inter_type = 'Mixed'\n",
    "        interactions.append({\n",
    "            'Variable 1': var1,\n",
    "            'Variable 2': var2,\n",
    "            'SHAP Inter Variation PRE-POST': variation,\n",
    "            'Interaction Type': inter_type\n",
    "        })\n",
    "\n",
    "# Explicit columns keep the sort below working even when no pair exceeds the tolerance\n",
    "interactions_df = pd.DataFrame(interactions, columns=['Variable 1', 'Variable 2', 'SHAP Inter Variation PRE-POST', 'Interaction Type'])\n",
    "\n",
    "# Largest absolute variation first\n",
    "sorted_interactions_df = interactions_df.sort_values('SHAP Inter Variation PRE-POST', key=lambda s: s.abs(), ascending=False)\n",
    "\n",
    "# Export to Excel\n",
    "sorted_interactions_df.to_excel(f'./results/pre_post_inter_diff/inter_variation_{method_name}.xlsx', index=False)\n",
    "\n",
    "print(\"Excel file has been created.\")"
   ]
433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
Joaquin Torres's avatar
Joaquin Torres committed
451
   "version": "3.12.2"
452 453 454 455 456
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}