{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**SHAP Explainability Plots** \\\n",
    "_Author: Joaquín Torres Bravo_"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import shap\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.patches as mpatches # Legend\n",
    "import matplotlib.colors as mcolors # Colormap\n",
    "import seaborn as sns\n",
    "from scipy.stats import zscore # Normalization"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Retrieve attribute (feature) names in column order\n",
    "attribute_names = list(np.load('../EDA/output/feature_names/all_features.npy', allow_pickle=True))\n",
    "\n",
    "# Load test data\n",
    "X_test_pre = np.load('../gen_train_data/output/pre/X_test_pre.npy', allow_pickle=True)\n",
    "y_test_pre = np.load('../gen_train_data/output/pre/y_test_pre.npy', allow_pickle=True)\n",
    "X_test_post = np.load('../gen_train_data/output/post/X_test_post.npy', allow_pickle=True)\n",
    "y_test_post = np.load('../gen_train_data/output/post/y_test_post.npy', allow_pickle=True)\n",
    "\n",
    "# Wrap the feature matrices in DataFrames with named columns (the SHAP plots label features from them);\n",
    "# convert_dtypes() infers column dtypes, since arrays loaded with allow_pickle=True may be object-typed\n",
    "data_dic = {\n",
    "    \"X_test_pre\": pd.DataFrame(X_test_pre, columns=attribute_names).convert_dtypes(),\n",
    "    \"y_test_pre\": y_test_pre,\n",
    "    \"X_test_post\": pd.DataFrame(X_test_post, columns=attribute_names).convert_dtypes(),\n",
    "    \"y_test_post\": y_test_post,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "method_names = {\n",
    "    0: \"ORIG\",\n",
    "    1: \"ORIG_CW\",\n",
    "    2: \"OVER\",\n",
    "    3: \"UNDER\"\n",
    "}\n",
    "# Model selected for each pipeline\n",
    "model_choices = {\n",
    "    \"ORIG\": \"XGB\",\n",
    "    \"ORIG_CW\": \"RF\",\n",
    "    \"OVER\": \"XGB\",\n",
    "    \"UNDER\": \"XGB\"\n",
    "}\n",
    "\n",
    "# Load names of social and individual attributes\n",
    "soc_var_names = np.load('../EDA/output/feature_names/social_factors.npy', allow_pickle=True)\n",
    "ind_var_names = np.load('../EDA/output/feature_names/individual_factors.npy', allow_pickle=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### SHAP Plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_shap_summary_pre_post(method_name, pipeline_label, xlim):\n",
    "    \"\"\"Plot the PRE and POST SHAP summary plots of one pipeline (one above the other) and save them as SVG.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    method_name : str\n",
    "        Pipeline key (see `model_choices`). SHAP values are read from\n",
    "        ./output/shap_values/{group}_{method_name}.npy for group in ('pre', 'post').\n",
    "    pipeline_label : str\n",
    "        Human-readable pipeline name shown in the figure title.\n",
    "    xlim : tuple of float\n",
    "        SHAP-value axis limits, shared by both subplots so PRE and POST are comparable.\n",
    "    \"\"\"\n",
    "    model_name = model_choices[method_name]\n",
    "    used_colors = {'purple': 'Social factor', 'green': 'Individual factor'}\n",
    "\n",
    "    plt.figure(figsize=(35, 75))\n",
    "    for i, group in enumerate(['pre', 'post']):\n",
    "        X_test = data_dic['X_test_' + group]\n",
    "        shap_vals = np.load(f'./output/shap_values/{group}_{method_name}.npy')\n",
    "        if shap_vals.ndim == 3:\n",
    "            # One set of SHAP values per class (e.g. the RF model of ORIG_CW): keep the positive class\n",
    "            shap_vals = shap_vals[:, :, 1]\n",
    "        ax = plt.subplot(2, 1, i + 1)  # 2 rows (pre - post), 1 column\n",
    "        # show=False so the plot can be modified before it is shown\n",
    "        shap.summary_plot(shap_vals, X_test, max_display=len(attribute_names), show=False)\n",
    "        plt.title(group.upper(), fontsize=12, fontweight='bold')\n",
    "        plt.xlabel('SHAP Value')\n",
    "        plt.xlim(*xlim)\n",
    "        # Color each feature name by factor type\n",
    "        for tick_label in ax.get_yticklabels():\n",
    "            tick_label.set_fontsize(8)\n",
    "            tick_label.set_color('purple' if tick_label.get_text() in soc_var_names else 'green')\n",
    "        # Custom legend for each subplot\n",
    "        handles = [mpatches.Patch(color=color, label=label) for color, label in used_colors.items()]\n",
    "        ax.legend(handles=handles, loc='lower right', fontsize=8)\n",
    "\n",
    "    plt.suptitle(f'SHAP Summary Plots PRE vs POST - Pipeline: {pipeline_label} - Model: {model_name}\\n\\n')\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(f'./output/plots/shap_summary/{method_name}_{model_name}.svg', format='svg', dpi=1250)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_shap_summary_pre_post('OVER', pipeline_label='Oversampling', xlim=(-3, 5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_shap_summary_pre_post('ORIG_CW', pipeline_label='Original with Class Weight', xlim=(-0.5, 0.5))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### SHAP Interaction Plots"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**IMPORTANT NOTE** \\\n",
    "For the code to work as intended, the SHAP source code had to be modified in `.venv/lib/python3.9/site-packages/shap/plots/_beeswarm.py`:\n",
    "\n",
    "`sort_inds = np.argsort(-np.abs(shap_values.sum(1)).sum(0))` \\\n",
    "**replaced by** \\\n",
    "`sort_inds = np.arange(39)` (39 is the number of features)\n",
    "\n",
    "This makes the interaction plots display the features in their original order instead of sorting them by absolute SHAP impact, which is the library's default behavior.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "used_colors = {'purple': 'Social factor', 'green': 'Individual factor'}\n",
    "\n",
    "for method_name in ['ORIG_CW', 'OVER']:\n",
    "    model_name = model_choices[method_name]\n",
    "    for group in ['pre', 'post']:\n",
    "        X_test = data_dic['X_test_' + group]\n",
    "\n",
    "        shap_inter_vals = np.load(f'./output/shap_inter_values/{group}_{method_name}.npy')\n",
    "        if method_name == 'ORIG_CW':\n",
    "            shap_inter_vals = shap_inter_vals[:, :, :, 1]  # Take info about positive class\n",
    "\n",
    "        # Set the diagonal and the lower triangle of every instance's interaction matrix to NaN,\n",
    "        # keeping only the strict upper triangle (one vectorized assignment over all instances)\n",
    "        num_features = shap_inter_vals.shape[1]\n",
    "        rows, cols = np.tril_indices(num_features)  # k=0: lower triangle including the diagonal\n",
    "        shap_inter_vals[:, rows, cols] = np.nan\n",
    "\n",
    "        # shap's (legacy beeswarm) summary_plot opens its own figure for interaction values, so no\n",
    "        # plt.figure() call precedes it (one would only leave an empty figure behind)\n",
    "        shap.summary_plot(shap_inter_vals, X_test, show=False, sort=False, max_display=len(attribute_names))\n",
    "        fig = plt.gcf()\n",
    "\n",
    "        # Color the y-axis feature names and collect the tick label texts of every axes, in order\n",
    "        attr_names = []\n",
    "        for ax in fig.get_axes():\n",
    "            for label in ax.get_yticklabels():\n",
    "                label_text = label.get_text()\n",
    "                attr_names.append(label_text)\n",
    "                label.set_fontsize(12)\n",
    "                label.set_color('purple' if label_text in soc_var_names else 'green')\n",
    "\n",
    "        # Title each column subplot with its feature name. The collected labels are read in reverse\n",
    "        # (y tick labels run bottom-to-top); assumes the first axes carries all feature names and there\n",
    "        # is one axes per feature -- TODO confirm for the shap version in use\n",
    "        total_axes = len(fig.axes)\n",
    "        for i, ax in enumerate(fig.axes):\n",
    "            title = attr_names[total_axes - 1 - i]\n",
    "            ax.set_title(title, color='purple' if title in soc_var_names else 'green', fontsize=12, rotation=90)\n",
    "            # Use the same scale for PRE and POST\n",
    "            if method_name == 'ORIG_CW':\n",
    "                ax.set_xlim(-0.15, 0.15)\n",
    "            elif method_name == 'OVER':\n",
    "                ax.set_xlim(-2, 2)\n",
    "\n",
    "        # Single general legend for the whole figure\n",
    "        handles = [mpatches.Patch(color=color, label=label) for color, label in used_colors.items()]\n",
    "        fig.legend(handles=handles, loc='lower right', fontsize=12)\n",
    "\n",
    "        plt.suptitle(f'SHAP Summary Interaction Plot - {method_name} - {group.upper()}\\n', fontsize=20, fontweight='bold')\n",
    "        plt.tight_layout()\n",
    "        plt.savefig(f'./output/plots/shap_inter_summary/{group.upper()}_{method_name}_{model_name}.svg', format='svg', dpi=700)\n",
    "        # The figures are only saved to disk (not shown); close each one so the many-axes\n",
    "        # figures are not kept open for the rest of the loop\n",
    "        plt.close(fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### SHAP Interaction Analysis"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Distance Heatmaps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pipeline analysed below (interaction values are available for 'ORIG_CW' and 'OVER')\n",
    "method_name = 'OVER'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def zscore_normalize(matrix):\n",
    "    \"\"\"Z-score normalize a matrix using the mean and standard deviation of all its entries.\n",
    "\n",
    "    A constant matrix (zero spread) is mapped to zeros instead of NaN, so a single\n",
    "    degenerate instance cannot turn the aggregated mean matrix into NaN.\n",
    "    \"\"\"\n",
    "    matrix = np.asarray(matrix, dtype=float)\n",
    "    if np.ptp(matrix) == 0:\n",
    "        return np.zeros_like(matrix)\n",
    "    return zscore(matrix, axis=None)  # Normalize across the entire matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the SHAP interaction values (one interaction matrix per instance) for PRE and POST\n",
    "shap_inter_vals_pre = np.load(f'./output/shap_inter_values/pre_{method_name}.npy')\n",
    "shap_inter_vals_post = np.load(f'./output/shap_inter_values/post_{method_name}.npy')\n",
    "if method_name == 'ORIG_CW':\n",
    "    shap_inter_vals_pre = shap_inter_vals_pre[:, :, :, 1]  # Take info about positive class\n",
    "    shap_inter_vals_post = shap_inter_vals_post[:, :, :, 1]\n",
    "\n",
    "# Z-score normalize each instance's matrix in each of the two arrays\n",
    "norm_shap_inter_vals_pre = np.array([zscore_normalize(matrix) for matrix in shap_inter_vals_pre])\n",
    "norm_shap_inter_vals_post = np.array([zscore_normalize(matrix) for matrix in shap_inter_vals_post])\n",
    "\n",
    "# Aggregate the matrices of each group with the element-wise mean\n",
    "agg_shap_inter_vals_pre = np.mean(norm_shap_inter_vals_pre, axis=0)\n",
    "agg_shap_inter_vals_post = np.mean(norm_shap_inter_vals_post, axis=0)\n",
    "\n",
    "# Difference between the aggregated matrices (positive = larger value in POST)\n",
    "dist_matrix = agg_shap_inter_vals_post - agg_shap_inter_vals_pre"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Color map\n",
    "colors = [(1, 0, 0), (1, 1, 1), (0, 1, 0)] # Red, White, Green\n",
    "n_bins = 100 # Discretize the colormap into 100 values\n",
    "cmap_name = 'custom_red_white_dark_green'\n",
    "cmap = mcolors.LinearSegmentedColormap.from_list(cmap_name, colors, N=n_bins)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_distance_heatmap(matrix, feature_names, soc_var_names, method_name):\n",
    "    \"\"\"Plot the POST - PRE interaction difference matrix as a lower-triangular heatmap and save it as SVG.\n",
    "\n",
    "    Feature names are colored purple (social factor) or green (individual factor).\n",
    "    Uses the colormap `cmap` defined in the previous cell. The figure is written to\n",
    "    ./output/plots/heatmaps_interactions/DIST_{method_name}.svg and then shown.\n",
    "    \"\"\"\n",
    "    # Mask the upper triangle (including the diagonal) so only the strict lower triangle is drawn\n",
    "    mask = np.triu(np.ones_like(matrix, dtype=bool))\n",
    "\n",
    "    plt.figure(figsize=(12, 12))\n",
    "    ax = sns.heatmap(matrix, mask=mask, cmap=cmap, center=0, annot=False, cbar=True,\n",
    "                     xticklabels=feature_names, yticklabels=feature_names)\n",
    "\n",
    "    # Color x and y tick labels by factor type\n",
    "    for tick_label in ax.get_xticklabels() + ax.get_yticklabels():\n",
    "        tick_label.set_color('purple' if tick_label.get_text() in soc_var_names else 'green')\n",
    "\n",
    "    plt.title(f'Distance Interaction Matrix between PRE and POST - Pipeline: {method_name}\\n', fontdict={'fontstyle': 'normal', 'weight': 'bold'})\n",
    "    # Custom legend for the factor types\n",
    "    purple_patch = mpatches.Patch(color='purple', label='Social factor')\n",
    "    green_patch = mpatches.Patch(color='green', label='Individual factor')\n",
    "    plt.legend(handles=[purple_patch, green_patch], loc='upper right')\n",
    "    # Title for the color bar\n",
    "    cbar = ax.collections[0].colorbar\n",
    "    cbar.set_label('Interaction POST - Interaction PRE', labelpad=15, rotation=270, verticalalignment='bottom')\n",
    "\n",
    "    plt.savefig(f'./output/plots/heatmaps_interactions/DIST_{method_name}.svg', format='svg', dpi=600)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_distance_heatmap(dist_matrix, attribute_names, soc_var_names, method_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Excel Differences Sorted"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Candidate pairs: strict lower triangle (each feature pair once, diagonal excluded),\n",
    "# the same triangle that the heatmap shows\n",
    "rows, cols = np.tril_indices(dist_matrix.shape[0], k=-1)\n",
    "pair_variations = dist_matrix[rows, cols]\n",
    "\n",
    "# Tolerance: median absolute variation over those candidate pairs only\n",
    "# (the diagonal is not a feature pair and the upper triangle duplicates the lower one)\n",
    "tolerance = np.median(np.abs(pair_variations))\n",
    "\n",
    "# Collect the interactions whose variation exceeds the tolerance\n",
    "interactions = []\n",
    "for i, j, variation in zip(rows, cols, pair_variations):\n",
    "    if abs(variation) <= tolerance:\n",
    "        continue\n",
    "    var1 = attribute_names[i]\n",
    "    var2 = attribute_names[j]\n",
    "    if var1 in soc_var_names and var2 in soc_var_names:\n",
    "        inter_type = 'Social'\n",
    "    elif var1 in ind_var_names and var2 in ind_var_names:\n",
    "        inter_type = 'Individual'\n",
    "    else:\n",
    "        inter_type = 'Mixed'\n",
    "    interactions.append({\n",
    "        'Variable 1': var1,\n",
    "        'Variable 2': var2,\n",
    "        'SHAP Inter Variation PRE-POST': variation,\n",
    "        'Interaction Type': inter_type\n",
    "    })\n",
    "\n",
    "# Explicit columns keep the export valid even when no pair exceeds the tolerance\n",
    "interactions_df = pd.DataFrame(\n",
    "    interactions,\n",
    "    columns=['Variable 1', 'Variable 2', 'SHAP Inter Variation PRE-POST', 'Interaction Type'])\n",
    "\n",
    "# Sort by absolute variation, largest first\n",
    "sorted_interactions_df = interactions_df.sort_values(\n",
    "    by='SHAP Inter Variation PRE-POST', key=lambda s: s.abs(), ascending=False)\n",
    "\n",
    "# Export to Excel\n",
    "sorted_interactions_df.to_excel(f'./output/inter_variation_{method_name}.xlsx', index=False)\n",
    "\n",
    "print(\"Excel file has been created.\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}