shap_plots.ipynb 22.6 KB
Newer Older
1 2 3 4 5 6
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
7 8 9 10 11 12 13 14 15
    "**SHAP Explainability Plots** \\\n",
    "_Author: Joaquín Torres Bravo_"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Libraries"
16 17 18 19
   ]
  },
  {
   "cell_type": "code",
20
   "execution_count": 2,
21 22 23 24 25 26
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import shap\n",
27
    "import matplotlib.pyplot as plt\n",
28 29
    "import matplotlib.patches as mpatches\n",
    "import seaborn as sns"
30 31 32 33 34 35
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
36
    "### Data"
37 38 39 40
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Retrieve attribute names in order\n",
    "# NOTE: allow_pickle=True unpickles arbitrary Python objects, so only load .npy files produced by this project\n",
    "attribute_names = list(np.load('../EDA/output/feature_names/all_features.npy', allow_pickle=True))\n",
    "\n",
    "# Load test data for both groups (pre / post)\n",
    "X_test_pre = np.load('../gen_train_data/output/pre/X_test_pre.npy', allow_pickle=True)\n",
    "y_test_pre = np.load('../gen_train_data/output/pre/y_test_pre.npy', allow_pickle=True)\n",
    "X_test_post = np.load('../gen_train_data/output/post/X_test_post.npy', allow_pickle=True)\n",
    "y_test_post = np.load('../gen_train_data/output/post/y_test_post.npy', allow_pickle=True)\n",
    "\n",
    "# Type conversion needed: wrap the feature arrays in DataFrames labelled with the attribute names\n",
    "# and let pandas pick a suitable dtype for every column\n",
    "data_dic = {\n",
    "    \"X_test_pre\": pd.DataFrame(X_test_pre, columns=attribute_names).convert_dtypes(),\n",
    "    \"y_test_pre\": y_test_pre,\n",
    "    \"X_test_post\": pd.DataFrame(X_test_post, columns=attribute_names).convert_dtypes(),\n",
    "    \"y_test_post\": y_test_post,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
78
   "execution_count": null,
79 80 81 82 83 84 85 86 87 88 89 90 91 92
   "metadata": {},
   "outputs": [],
   "source": [
    "method_names = {\n",
    "    0: \"ORIG\",\n",
    "    1: \"ORIG_CW\",\n",
    "    2: \"OVER\",\n",
    "    3: \"UNDER\"\n",
    "}\n",
    "model_choices = {\n",
    "    \"ORIG\": \"XGB\",\n",
    "    \"ORIG_CW\": \"RF\",\n",
    "    \"OVER\": \"XGB\",\n",
    "    \"UNDER\": \"XGB\"\n",
93 94
    "}\n",
    "\n",
95
    "# Load names of social and individual attributes\n",
Joaquin Torres's avatar
Joaquin Torres committed
96 97
    "soc_var_names = np.load('../EDA/output/feature_names/social_factors.npy', allow_pickle=True)\n",
    "ind_var_names = np.load('../EDA/output/feature_names/individual_factors.npy', allow_pickle=True)"
98 99 100 101 102 103
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
104
    "### SHAP Plots"
105 106 107 108
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SHAP summary plots (PRE on top, POST below) for the oversampling pipeline\n",
    "method_name = 'OVER'\n",
    "model_name = model_choices[method_name]\n",
    "# Tick-label colors and the legend text that explains them\n",
    "used_colors = {'purple': 'Social factor', 'green': 'Individual factor'}\n",
    "\n",
    "plt.figure(figsize=(35, 75))\n",
    "for i, group in enumerate(['pre', 'post']):\n",
    "    X_test = data_dic['X_test_' + group]\n",
    "    shap_vals = np.load(f'./output/shap_values/{group}_{method_name}.npy')\n",
    "    ax = plt.subplot(2, 1, i + 1)  # 2 rows (pre - post), 1 column\n",
    "    # show=False to modify the plot before showing it\n",
    "    shap.summary_plot(shap_vals, X_test, max_display=len(attribute_names), show=False)\n",
    "    plt.title(group.upper(), fontsize=12, fontweight='bold')\n",
    "    plt.xlabel('SHAP Value')\n",
    "    plt.xlim(-3, 5)  # same x range for both subplots\n",
    "    # Color every attribute name according to its type\n",
    "    for label in ax.get_yticklabels():\n",
    "        label.set_fontsize(8)\n",
    "        label.set_color('purple' if label.get_text() in soc_var_names else 'green')\n",
    "    # Create the custom legend once per subplot (not once per tick label)\n",
    "    handles = [mpatches.Patch(color=color, label=text) for color, text in used_colors.items()]\n",
    "    ax.legend(handles=handles, loc='lower right', fontsize=8)\n",
    "\n",
    "plt.suptitle(f'SHAP Summary Plots PRE vs POST - Pipeline: Oversampling - Model: {model_name}\\n\\n')\n",
    "plt.subplots_adjust(wspace=1)\n",
    "plt.tight_layout()\n",
    "plt.savefig(f'./output/plots/shap_summary/{method_name}_{model_name}.svg', format='svg', dpi=1250)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SHAP summary plots (PRE on top, POST below) for the original pipeline with class weights\n",
    "method_name = 'ORIG_CW'\n",
    "model_name = model_choices[method_name]\n",
    "# Tick-label colors and the legend text that explains them\n",
    "used_colors = {'purple': 'Social factor', 'green': 'Individual factor'}\n",
    "\n",
    "plt.figure(figsize=(35, 75))\n",
    "for i, group in enumerate(['pre', 'post']):\n",
    "    X_test = data_dic['X_test_' + group]\n",
    "    shap_vals = np.load(f'./output/shap_values/{group}_{method_name}.npy')\n",
    "    shap_vals = shap_vals[:, :, 1]  # Select shap values for positive class\n",
    "    ax = plt.subplot(2, 1, i + 1)  # 2 rows (pre - post), 1 column\n",
    "    # show=False to modify the plot before showing it\n",
    "    shap.summary_plot(shap_vals, X_test, max_display=len(attribute_names), show=False)\n",
    "    plt.title(group.upper(), fontsize=12, fontweight='bold')\n",
    "    plt.xlabel('SHAP Value')\n",
    "    plt.xlim(-0.5, 0.5)  # same x range for both subplots\n",
    "    # Color every attribute name according to its type\n",
    "    for label in ax.get_yticklabels():\n",
    "        label.set_fontsize(8)\n",
    "        label.set_color('purple' if label.get_text() in soc_var_names else 'green')\n",
    "    # Create the custom legend once per subplot (not once per tick label)\n",
    "    handles = [mpatches.Patch(color=color, label=text) for color, text in used_colors.items()]\n",
    "    ax.legend(handles=handles, loc='lower right', fontsize=8)\n",
    "\n",
    "plt.suptitle(f'SHAP Summary Plots PRE vs POST - Pipeline: Original with Class Weight - Model: {model_name}\\n\\n')\n",
    "plt.subplots_adjust(wspace=1)\n",
    "plt.tight_layout()\n",
    "plt.savefig(f'./output/plots/shap_summary/{method_name}_{model_name}.svg', format='svg', dpi=1250)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
200 201 202 203 204 205 206 207 208 209
    "### SHAP Interaction Plots"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**IMPORTANT NOTE**\n",
    "\n",
    "For the code to work as intended, the SHAP source code had to be modified:\n",
    "* **File**: `shap/plots/_beeswarm.py` (installed at `.venv/lib/python3.9/site-packages/shap/plots/_beeswarm.py`)\n",
    "* **Modification**: the line\n",
    "\n",
    "  `sort_inds = np.argsort(-np.abs(shap_values.sum(1)).sum(0))`\n",
    "\n",
    "  was **replaced by**\n",
    "\n",
    "  `sort_inds = np.arange(39)`\n",
    "\n",
    "The idea is to display the 39 features in their original order instead of sorting them according to their absolute SHAP impact, as the library originally does.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Had to modify `_beeswarm.py` (line 591)."
   ]
  },
  {
   "cell_type": "code",
228
   "execution_count": null,
229 230 231
   "metadata": {},
   "outputs": [],
   "source": [
Joaquin Torres's avatar
Joaquin Torres committed
232
    "method_name = 'ORIG_CW'\n",
233
    "group = 'pre'"
234 235 236 237
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simplified SHAP interaction summary plot for the selected method and group\n",
    "X_test = data_dic['X_test_' + group]\n",
    "\n",
    "shap_inter_vals = np.load(f'./output/shap_inter_values/{group}_{method_name}.npy')\n",
    "if method_name == 'ORIG_CW':\n",
    "    shap_inter_vals = shap_inter_vals[:, :, :, 1]  # Take info about positive class\n",
    "\n",
    "num_features = shap_inter_vals.shape[1]  # Assuming the number of features is the second dimension size\n",
    "# Set the diagonal and the lower triangle of every per-instance matrix to NaN in a single\n",
    "# vectorized assignment, so that only the upper triangle is drawn\n",
    "tril_rows, tril_cols = np.tril_indices(num_features)  # k=0 (default) includes the diagonal\n",
    "shap_inter_vals[:, tril_rows, tril_cols] = np.nan\n",
    "\n",
    "plt.figure(figsize=(10, 10))\n",
    "shap.summary_plot(shap_inter_vals, X_test, show=False, sort=False)\n",
    "fig = plt.gcf()\n",
    "attr_names = []\n",
    "used_colors = {'purple': 'Social factor', 'green': 'Individual factor'}\n",
    "# Iterate over all axes in the figure\n",
    "for ax in fig.get_axes():\n",
    "    # Customize the y-axis tick labels and remember their text for the titles below\n",
    "    for label in ax.get_yticklabels():\n",
    "        label_text = label.get_text()  # Get the text of the label\n",
    "        attr_names.append(label_text)\n",
    "        label.set_fontsize(12)\n",
    "        label.set_color('purple' if label_text in soc_var_names else 'green')\n",
    "\n",
    "# Assuming the top labels are treated as titles, let's try to modify them\n",
    "# NOTE(review): relies on the collected tick labels being in reverse order of the axes - confirm against shap's layout\n",
    "total_axes = len(fig.axes)\n",
    "for i, ax in enumerate(fig.axes):\n",
    "    reverse_index = total_axes - 1 - i\n",
    "    title = attr_names[reverse_index]\n",
    "    ax.set_title(title, color='purple' if title in soc_var_names else 'green', fontsize=12, rotation=90)\n",
    "    if method_name == 'ORIG_CW':\n",
    "        ax.set_xlim(-0.15, 0.15)  # Use same scale for pre and post\n",
    "    elif method_name == 'OVER':\n",
    "        ax.set_xlim(-2, 2)\n",
    "\n",
    "# Create a single general legend for the whole figure\n",
    "handles = [mpatches.Patch(color=color, label=text) for color, text in used_colors.items()]\n",
    "fig.legend(handles=handles, loc='lower right', fontsize=12)\n",
    "\n",
    "plt.suptitle('Simplified Example SHAP Summary Interaction Plot\\n', fontsize=15, fontweight='bold', x=0.5, y=0.95, ha='center')\n",
    "plt.tight_layout()\n",
    "plt.savefig('./output/plots/shap_inter_summary/example.svg', format='svg', dpi=600)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### SHAP Interaction Analysis"
   ]
  },
  {
   "cell_type": "code",
308
   "execution_count": null,
309 310 311 312 313 314 315 316
   "metadata": {},
   "outputs": [],
   "source": [
    "method_name = 'OVER'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalization function\n",
    "def max_absolute_scale(matrix):\n",
    "    \"\"\"Scale ``matrix`` so that its largest absolute entry becomes 1.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    matrix : array-like\n",
    "        Values to normalize.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    numpy.ndarray\n",
    "        ``matrix`` divided by its maximum absolute value. An all-zero input is\n",
    "        returned as zeros instead of NaNs (0 / 0), so that it cannot poison a\n",
    "        later aggregation such as a mean.\n",
    "    \"\"\"\n",
    "    matrix = np.asarray(matrix)\n",
    "    max_abs = np.max(np.abs(matrix))\n",
    "    if max_abs == 0:\n",
    "        # Nothing to scale: avoid the division by zero\n",
    "        return np.zeros_like(matrix, dtype=float)\n",
    "    return matrix / max_abs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the stacks of SHAP interaction matrices (PRE and POST) of the chosen method\n",
    "shap_inter_vals_pre, shap_inter_vals_post = [\n",
    "    np.load(f'./output/shap_inter_values/{grp}_{method_name}.npy') for grp in ('pre', 'post')\n",
    "]\n",
    "if method_name == 'ORIG_CW':\n",
    "    # Keep index 1 of the last axis (positive class)\n",
    "    shap_inter_vals_pre = shap_inter_vals_pre[:, :, :, 1]\n",
    "    shap_inter_vals_post = shap_inter_vals_post[:, :, :, 1]\n",
    "\n",
    "# Scale every per-instance matrix by its own maximum absolute value\n",
    "norm_shap_inter_vals_pre = np.array(list(map(max_absolute_scale, shap_inter_vals_pre)))\n",
    "norm_shap_inter_vals_post = np.array(list(map(max_absolute_scale, shap_inter_vals_post)))\n",
    "\n",
    "# Average the normalized matrices over the instances of each group\n",
    "agg_shap_inter_vals_pre = np.mean(norm_shap_inter_vals_pre, axis=0)\n",
    "agg_shap_inter_vals_post = np.mean(norm_shap_inter_vals_post, axis=0)\n",
    "\n",
    "# Element-wise absolute difference between the two aggregated matrices\n",
    "dist_matrix = np.abs(agg_shap_inter_vals_post - agg_shap_inter_vals_pre)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_inter_heatmap(matrix, feature_names, soc_var_names, method_name, group):\n",
    "    \"\"\"Draw, save and show a lower-triangle heatmap of an interaction matrix.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    matrix : array-like\n",
    "        Values to plot; the upper triangle (diagonal included) is masked out.\n",
    "    feature_names : sequence of str\n",
    "        Tick labels used on both axes.\n",
    "    soc_var_names : collection of str\n",
    "        Tick labels found here are drawn in purple, all the others in green.\n",
    "    method_name : str\n",
    "        Shown in the title and used in the name of the saved SVG file.\n",
    "    group : str\n",
    "        Shown in the title and used in the name of the saved SVG file.\n",
    "    \"\"\"\n",
    "    # Boolean mask that hides the upper triangle, diagonal included\n",
    "    upper_mask = np.triu(np.ones_like(matrix, dtype=bool))\n",
    "\n",
    "    plt.figure(figsize=(12, 12))\n",
    "    ax = sns.heatmap(\n",
    "        matrix, mask=upper_mask, cmap='coolwarm', center=0, annot=False, cbar=True,\n",
    "        xticklabels=feature_names, yticklabels=feature_names,\n",
    "    )\n",
    "\n",
    "    # Color the tick labels of both axes by attribute type\n",
    "    for tick_label in [*ax.get_yticklabels(), *ax.get_xticklabels()]:\n",
    "        is_social = tick_label.get_text() in soc_var_names\n",
    "        tick_label.set_color('purple' if is_social else 'green')\n",
    "\n",
    "    plt.title(f'Mean Interaction Matrix - Pipeline: {method_name} - Group: {group}\\n', fontdict={'fontstyle': 'normal', 'weight': 'bold'})\n",
    "    # Custom legend explaining the two tick-label colors\n",
    "    legend_handles = [\n",
    "        mpatches.Patch(color='purple', label='Social factor'),\n",
    "        mpatches.Patch(color='green', label='Individual factor'),\n",
    "    ]\n",
    "    plt.legend(handles=legend_handles, loc='upper right')\n",
    "    # Label the color bar attached to the heatmap\n",
    "    color_bar = ax.collections[0].colorbar\n",
    "    color_bar.set_label('Normalized SHAP Interaction Value', labelpad=15, rotation=270, verticalalignment='bottom')\n",
    "\n",
    "    plt.savefig(f'./output/plots/heatmaps_interactions/{method_name}_{group}.svg', format='svg', dpi=600)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
393
   "execution_count": null,
394
   "metadata": {},
395
   "outputs": [],
396
   "source": [
Joaquin Torres's avatar
Joaquin Torres committed
397
    "plot_inter_heatmap(agg_shap_inter_vals_post, attribute_names, soc_var_names, method_name, 'POST')"
398 399 400 401
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_distance_heatmap(matrix, feature_names, soc_var_names, method_name):\n",
    "    \"\"\"Draw, save and show a lower-triangle heatmap of a PRE/POST distance matrix.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    matrix : array-like\n",
    "        Values to plot; the upper triangle (diagonal included) is masked out.\n",
    "    feature_names : sequence of str\n",
    "        Tick labels used on both axes.\n",
    "    soc_var_names : collection of str\n",
    "        Tick labels found here are drawn in purple, all the others in green.\n",
    "    method_name : str\n",
    "        Shown in the title and used in the name of the saved SVG file.\n",
    "    \"\"\"\n",
    "    # Boolean mask that hides the upper triangle, diagonal included\n",
    "    upper_mask = np.triu(np.ones_like(matrix, dtype=bool))\n",
    "\n",
    "    plt.figure(figsize=(12, 12))\n",
    "    red_cmap = sns.color_palette(\"light:r\", as_cmap=True)\n",
    "    ax = sns.heatmap(\n",
    "        matrix, mask=upper_mask, cmap=red_cmap, annot=False, cbar=True,\n",
    "        xticklabels=feature_names, yticklabels=feature_names,\n",
    "    )\n",
    "\n",
    "    # Color the tick labels of both axes by attribute type\n",
    "    for tick_label in [*ax.get_yticklabels(), *ax.get_xticklabels()]:\n",
    "        is_social = tick_label.get_text() in soc_var_names\n",
    "        tick_label.set_color('purple' if is_social else 'green')\n",
    "\n",
    "    plt.title(f'Distance Interaction Matrix between PRE and POST - Pipeline: {method_name}\\n', fontdict={'fontstyle': 'normal', 'weight': 'bold'})\n",
    "    # Custom legend explaining the two tick-label colors\n",
    "    legend_handles = [\n",
    "        mpatches.Patch(color='purple', label='Social factor'),\n",
    "        mpatches.Patch(color='green', label='Individual factor'),\n",
    "    ]\n",
    "    plt.legend(handles=legend_handles, loc='upper right')\n",
    "    # Label the color bar attached to the heatmap\n",
    "    color_bar = ax.collections[0].colorbar\n",
    "    color_bar.set_label('Abs (PRE - POST)', labelpad=15, rotation=270, verticalalignment='bottom')\n",
    "\n",
    "    plt.savefig(f'./output/plots/heatmaps_interactions/DIST_{method_name}.svg', format='svg', dpi=600)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
443
   "execution_count": null,
444
   "metadata": {},
445
   "outputs": [],
446
   "source": [
447
    "plot_distance_heatmap(dist_matrix, attribute_names, soc_var_names, method_name)"
448
   ]
Joaquin Torres's avatar
Joaquin Torres committed
449 450 451
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define tolerance: the median of the whole distance matrix\n",
    "tolerance = np.median(dist_matrix)\n",
    "# Column order of the exported table; passed explicitly so the table keeps its\n",
    "# columns (and can still be sorted) when no interaction exceeds the tolerance\n",
    "interaction_columns = ['Variable 1', 'Variable 2', 'SHAP Inter Variation PRE-POST', 'Interaction Type']\n",
    "# List of records, one per interaction above the tolerance\n",
    "interactions = []\n",
    "\n",
    "# Iterate over the matrix to extract interactions above the tolerance\n",
    "for i in range(1, dist_matrix.shape[0]):\n",
    "    for j in range(i):  # Lower triangle excluding diagonal\n",
    "        if dist_matrix[i, j] > tolerance:\n",
    "            var1 = attribute_names[i]\n",
    "            var2 = attribute_names[j]\n",
    "            if var1 in soc_var_names and var2 in soc_var_names:\n",
    "                inter_type = 'Social'\n",
    "            elif var1 in ind_var_names and var2 in ind_var_names:\n",
    "                inter_type = 'Individual'\n",
    "            else:\n",
    "                inter_type = 'Mixed'\n",
    "            interactions.append({\n",
    "                'Variable 1': var1,\n",
    "                'Variable 2': var2,\n",
    "                'SHAP Inter Variation PRE-POST': dist_matrix[i, j],\n",
    "                'Interaction Type': inter_type\n",
    "            })\n",
    "\n",
    "# Convert the list of dictionaries to a DataFrame\n",
    "interactions_df = pd.DataFrame(interactions, columns=interaction_columns)\n",
    "\n",
    "# Sort the DataFrame by 'SHAP Inter Variation PRE-POST' in descending order\n",
    "sorted_interactions_df = interactions_df.sort_values(by='SHAP Inter Variation PRE-POST', ascending=False)\n",
    "\n",
    "# Export to Excel\n",
    "sorted_interactions_df.to_excel(f'./output/inter_variation_{method_name}.xlsx', index=False)\n",
    "\n",
    "print(\"Excel file has been created.\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
509
   "version": "3.9.5"
510 511 512 513 514
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}