{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import shap\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.patches as mpatches\n",
    "import seaborn as sns"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output directory of the training-data generation step; all test-set files below are read from it\n",
    "DATA_DIR = '../gen_train_data/data/output'\n",
    "\n",
    "# Retrieve attribute names in order\n",
    "attribute_names = list(np.load(f'{DATA_DIR}/attributes.npy', allow_pickle=True))\n",
    "\n",
    "# Load test data (features X, labels y) for the PRE and POST groups.\n",
    "# NOTE: allow_pickle=True can execute arbitrary code - only load trusted files.\n",
    "X_test_pre = np.load(f'{DATA_DIR}/pre/X_test_pre.npy', allow_pickle=True)\n",
    "y_test_pre = np.load(f'{DATA_DIR}/pre/y_test_pre.npy', allow_pickle=True)\n",
    "X_test_post = np.load(f'{DATA_DIR}/post/X_test_post.npy', allow_pickle=True)\n",
    "y_test_post = np.load(f'{DATA_DIR}/post/y_test_post.npy', allow_pickle=True)\n",
    "\n",
    "# Sanity check: every feature row has exactly one label\n",
    "assert len(X_test_pre) == len(y_test_pre), 'PRE: X/y length mismatch'\n",
    "assert len(X_test_post) == len(y_test_post), 'POST: X/y length mismatch'\n",
    "\n",
    "# Type conversion needed: the arrays are loaded with allow_pickle=True and are presumably\n",
    "# object-dtype, so convert_dtypes() lets pandas infer a concrete dtype for each column\n",
    "data_dic = {\n",
    "    'X_test_pre': pd.DataFrame(X_test_pre, columns=attribute_names).convert_dtypes(),\n",
    "    'y_test_pre': y_test_pre,\n",
    "    'X_test_post': pd.DataFrame(X_test_post, columns=attribute_names).convert_dtypes(),\n",
    "    'y_test_post': y_test_post,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pipeline identifiers: ORIG = original data, ORIG_CW = original with class weight,\n",
    "# OVER = oversampling, UNDER = undersampling (presumably)\n",
    "method_names = {\n",
    "    0: 'ORIG',\n",
    "    1: 'ORIG_CW',\n",
    "    2: 'OVER',\n",
    "    3: 'UNDER'\n",
    "}\n",
    "# Model used for each pipeline (appears in plot titles and output file names)\n",
    "model_choices = {\n",
    "    'ORIG': 'XGB',\n",
    "    'ORIG_CW': 'RF',\n",
    "    'OVER': 'XGB',\n",
    "    'UNDER': 'XGB'\n",
    "}\n",
    "\n",
    "# Names of the social-factor and individual-factor variables (produced by the EDA step)\n",
    "soc_var_names = np.load('../EDA/soc_vars_names.npy', allow_pickle=True)\n",
    "ind_var_names = np.load('../EDA/ind_vars_names.npy', allow_pickle=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### SHAP Plots"
   ]
  },
  {
   "cell_type": "code",
"execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shared helpers for the SHAP summary and interaction plots below\n",
    "FACTOR_COLORS = {'purple': 'Social factor', 'green': 'Individual factor'}\n",
    "\n",
    "\n",
    "def color_factor_labels(tick_labels, soc_var_names, fontsize=None):\n",
    "    \"\"\"Colour tick labels purple for social factors and green for every other variable.\n",
    "\n",
    "    tick_labels: matplotlib Text objects, e.g. ax.get_yticklabels(); modified in place.\n",
    "    soc_var_names: collection of social-factor variable names.\n",
    "    fontsize: optional font size applied to every label.\n",
    "    \"\"\"\n",
    "    for tick_label in tick_labels:\n",
    "        if fontsize is not None:\n",
    "            tick_label.set_fontsize(fontsize)\n",
    "        is_social = tick_label.get_text() in soc_var_names\n",
    "        tick_label.set_color('purple' if is_social else 'green')\n",
    "\n",
    "\n",
    "def factor_legend_handles():\n",
    "    \"\"\"Return legend patches explaining the purple/green label colour code.\"\"\"\n",
    "    return [mpatches.Patch(color=color, label=label) for color, label in FACTOR_COLORS.items()]\n",
    "\n",
    "\n",
    "def plot_shap_summary(method_name, pipeline_label, xlim, groups=('pre', 'post')):\n",
    "    \"\"\"Stack one SHAP beeswarm summary per group for a pipeline and save the figure as SVG.\n",
    "\n",
    "    method_name: pipeline key ('OVER', 'ORIG_CW', ...); selects the saved SHAP values and the model.\n",
    "    pipeline_label: human-readable pipeline name shown in the figure title.\n",
    "    xlim: (min, max) shared by all subplots so the groups are visually comparable.\n",
    "    groups: data groups to plot, one subplot each, top to bottom.\n",
    "    \"\"\"\n",
    "    model_name = model_choices[method_name]\n",
    "    plt.figure(figsize=(35, 75))\n",
    "    for i, group in enumerate(groups):\n",
    "        X_test = data_dic['X_test_' + group]\n",
    "        shap_vals = np.load(f'./output/shap_values/{group}_{method_name}.npy')\n",
    "        # Per-class values (n_samples, n_features, n_classes), as saved for ORIG_CW: keep the positive class\n",
    "        if shap_vals.ndim == 3:\n",
    "            shap_vals = shap_vals[:, :, 1]\n",
    "        ax = plt.subplot(len(groups), 1, i + 1)\n",
    "        shap.summary_plot(shap_vals, X_test, max_display=len(attribute_names), show=False)\n",
    "        plt.title(group.upper(), fontsize=12, fontweight='bold')\n",
    "        plt.xlabel('SHAP Value')\n",
    "        plt.xlim(*xlim)\n",
    "        color_factor_labels(ax.get_yticklabels(), soc_var_names, fontsize=8)\n",
    "        ax.legend(handles=factor_legend_handles(), loc='lower right', fontsize=8)\n",
    "\n",
    "    plt.suptitle(f'SHAP Summary Plots PRE vs POST - Pipeline: {pipeline_label} - Model: {model_name}\\n\\n')\n",
    "    plt.tight_layout()\n",
    "    # SVG is vector output; dpi only affects elements that end up rasterised\n",
    "    plt.savefig(f'./output/plots/shap_summary/{method_name}_{model_name}.svg', format='svg', dpi=1250)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_shap_summary('OVER', 'Oversampling', xlim=(-3, 5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_shap_summary('ORIG_CW', 'Original with Class Weight', xlim=(-0.5, 0.5))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### SHAP Interaction Plots"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Note:** the interaction plot below relies on a locally modified copy of SHAP's `beeswarm.py` (changed at line 591), so it may not render identically with an unmodified SHAP installation. TODO: document the exact patch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pipeline and data group shown in the example interaction plot\n",
    "method_name = 'ORIG_CW'\n",
    "group = 'pre'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_test = data_dic['X_test_' + group]\n",
    "\n",
    "shap_inter_vals = np.load(f'./output/shap_inter_values/{group}_{method_name}.npy')\n",
    "# Per-class values (n_samples, n_features, n_features, n_classes), as saved for ORIG_CW: keep the positive class\n",
    "if shap_inter_vals.ndim == 4:\n",
    "    shap_inter_vals = shap_inter_vals[..., 1]\n",
    "\n",
    "# Keep each feature pair once: set the diagonal and the lower triangle of every\n",
    "# (n_features x n_features) matrix to NaN in a single vectorised assignment\n",
    "num_features = shap_inter_vals.shape[1]\n",
    "tril_rows, tril_cols = np.tril_indices(num_features)  # k=0: includes the diagonal\n",
    "shap_inter_vals[:, tril_rows, tril_cols] = np.nan\n",
    "\n",
    "plt.figure(figsize=(10, 10))\n",
    "shap.summary_plot(shap_inter_vals, X_test, show=False, sort=False)\n",
    "fig = plt.gcf()\n",
    "\n",
    "# Recolour the feature labels of every panel and record them in drawing order\n",
    "attr_names = []\n",
    "for ax in fig.get_axes():\n",
    "    tick_labels = ax.get_yticklabels()\n",
    "    attr_names.extend(label.get_text() for label in tick_labels)\n",
    "    color_factor_labels(tick_labels, soc_var_names, fontsize=12)\n",
    "\n",
    "# Use the recorded labels, in reverse order, as rotated panel titles.\n",
    "# NOTE(review): assumes the collected labels line up with the panels in reverse order;\n",
    "# confirm against the patched beeswarm.py.\n",
    "panel_xlims = {'ORIG_CW': (-0.15, 0.15), 'OVER': (-2, 2)}  # same scale for pre and post\n",
    "total_axes = len(fig.axes)\n",
    "for i, ax in enumerate(fig.axes):\n",
    "    title = attr_names[total_axes - 1 - i]\n",
    "    ax.set_title(title, color='purple' if title in soc_var_names else 'green', fontsize=12, rotation=90)\n",
    "    if method_name in panel_xlims:\n",
    "        ax.set_xlim(*panel_xlims[method_name])\n",
    "\n",
    "# A single legend for the whole figure\n",
    "fig.legend(handles=factor_legend_handles(), loc='lower right', fontsize=12)\n",
    "\n",
    "plt.suptitle('Simplified Example SHAP Summary Interaction Plot\\n', fontsize=15, fontweight='bold', x=0.5, y=0.95, ha='center')\n",
"plt.tight_layout()\n",
    "plt.savefig(f'./output/plots/shap_inter_summary/example.svg', format='svg', dpi=600)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### SHAP Interaction Analysis\n",
    "\n",
    "Each instance's interaction matrix is scaled by its largest absolute entry, the scaled matrices are averaged per group (PRE / POST), and the two group means are compared through their element-wise absolute difference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "method_name = 'OVER'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def max_absolute_scale(matrix):\n",
    "    \"\"\"Scale a matrix into [-1, 1] by dividing it by its largest absolute entry.\n",
    "\n",
    "    An all-zero matrix is returned as zeros (float) instead of the NaNs that 0 / 0 would produce.\n",
    "    \"\"\"\n",
    "    max_abs = np.max(np.abs(matrix))\n",
    "    if max_abs == 0:\n",
    "        return np.zeros_like(matrix, dtype=float)\n",
    "    return matrix / max_abs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the arrays of SHAP interaction matrices (one per instance) for PRE and POST\n",
    "shap_inter_vals_pre = np.load(f'./output/shap_inter_values/pre_{method_name}.npy')\n",
    "shap_inter_vals_post = np.load(f'./output/shap_inter_values/post_{method_name}.npy')\n",
    "# Per-class values (n_samples, n_features, n_features, n_classes), as saved for ORIG_CW: keep the positive class\n",
    "if shap_inter_vals_pre.ndim == 4:\n",
    "    shap_inter_vals_pre = shap_inter_vals_pre[..., 1]\n",
    "if shap_inter_vals_post.ndim == 4:\n",
    "    shap_inter_vals_post = shap_inter_vals_post[..., 1]\n",
    "\n",
    "# Normalize each matrix independently so every instance contributes on the same scale\n",
    "norm_shap_inter_vals_pre = np.array([max_absolute_scale(matrix) for matrix in shap_inter_vals_pre])\n",
    "norm_shap_inter_vals_post = np.array([max_absolute_scale(matrix) for matrix in shap_inter_vals_post])\n",
    "\n",
    "# Aggregate each group by its mean normalized matrix\n",
    "agg_shap_inter_vals_pre = np.mean(norm_shap_inter_vals_pre, axis=0)\n",
    "agg_shap_inter_vals_post = np.mean(norm_shap_inter_vals_post, axis=0)\n",
    "assert agg_shap_inter_vals_pre.shape == agg_shap_inter_vals_post.shape, 'PRE/POST feature sets differ'\n",
    "\n",
    "# Per-pair PRE-POST variation: absolute difference of the aggregated matrices\n",
    "dist_matrix = np.abs(agg_shap_inter_vals_post - agg_shap_inter_vals_pre)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _plot_lower_triangle_heatmap(matrix, feature_names, soc_var_names, title, cbar_label, out_path, **heatmap_kwargs):\n",
    "    \"\"\"Draw the strictly lower triangle of a square feature x feature matrix and save it as SVG.\n",
    "\n",
    "    Tick labels are coloured purple for social factors and green for all other variables,\n",
    "    with a legend explaining the colour code. Extra keyword arguments (e.g. cmap, center)\n",
    "    are forwarded to sns.heatmap.\n",
    "    \"\"\"\n",
    "    # Mask the upper triangle, diagonal included\n",
    "    mask = np.triu(np.ones_like(matrix, dtype=bool))\n",
    "\n",
    "    plt.figure(figsize=(12, 12))\n",
    "    ax = sns.heatmap(matrix, mask=mask, annot=False, cbar=True,\n",
    "                     xticklabels=feature_names, yticklabels=feature_names, **heatmap_kwargs)\n",
    "\n",
    "    for tick_label in [*ax.get_xticklabels(), *ax.get_yticklabels()]:\n",
    "        is_social = tick_label.get_text() in soc_var_names\n",
    "        tick_label.set_color('purple' if is_social else 'green')\n",
    "\n",
    "    plt.title(title, fontdict={'fontstyle': 'normal', 'weight': 'bold'})\n",
    "    purple_patch = mpatches.Patch(color='purple', label='Social factor')\n",
    "    green_patch = mpatches.Patch(color='green', label='Individual factor')\n",
    "    plt.legend(handles=[purple_patch, green_patch], loc='upper right')\n",
    "    # Add a title to the color bar\n",
    "    cbar = ax.collections[0].colorbar\n",
    "    cbar.set_label(cbar_label, labelpad=15, rotation=270, verticalalignment='bottom')\n",
    "\n",
    "    plt.savefig(out_path, format='svg', dpi=600)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "def plot_inter_heatmap(matrix, feature_names, soc_var_names, method_name, group):\n",
    "    \"\"\"Plot a group's mean normalized SHAP interaction matrix (diverging colormap centred at 0).\"\"\"\n",
    "    _plot_lower_triangle_heatmap(\n",
    "        matrix, feature_names, soc_var_names,\n",
    "        title=f'Mean Interaction Matrix - Pipeline: {method_name} - Group: {group}\\n',\n",
    "        cbar_label='Normalized SHAP Interaction Value',\n",
    "        out_path=f'./output/plots/heatmaps_interactions/{method_name}_{group}.svg',\n",
    "        cmap='coolwarm',\n",
    "        center=0,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_inter_heatmap(agg_shap_inter_vals_post, attribute_names, soc_var_names, method_name, 'POST')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_distance_heatmap(matrix, feature_names, soc_var_names, method_name):\n",
    "    \"\"\"Plot the PRE vs POST distance matrix |POST - PRE| (sequential red colormap).\"\"\"\n",
    "    _plot_lower_triangle_heatmap(\n",
    "        matrix, feature_names, soc_var_names,\n",
    "        title=f'Distance Interaction Matrix between PRE and POST - Pipeline: {method_name}\\n',\n",
    "        cbar_label='Abs (PRE - POST)',\n",
    "        out_path=f'./output/plots/heatmaps_interactions/DIST_{method_name}.svg',\n",
    "        cmap=sns.color_palette('light:r', as_cmap=True),\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_distance_heatmap(dist_matrix, attribute_names, soc_var_names, method_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect each feature pair once: the strictly lower triangle (i > j), row by row\n",
    "pair_rows, pair_cols = np.tril_indices(dist_matrix.shape[0], k=-1)\n",
    "pair_variation = dist_matrix[pair_rows, pair_cols]\n",
    "\n",
    "# Tolerance = median variation over exactly those pairs. The median of the full matrix\n",
    "# would also count the diagonal and, if the matrix is symmetric, every pair twice.\n",
    "tolerance = np.median(pair_variation)\n",
    "\n",
    "\n",
    "def interaction_type(var1, var2):\n",
    "    \"\"\"Return 'Social' or 'Individual' when both variables belong to that list, else 'Mixed'.\"\"\"\n",
    "    if var1 in soc_var_names and var2 in soc_var_names:\n",
    "        return 'Social'\n",
    "    if var1 in ind_var_names and var2 in ind_var_names:\n",
    "        return 'Individual'\n",
    "    return 'Mixed'\n",
    "\n",
    "\n",
    "# Keep the pairs whose PRE-POST variation exceeds the tolerance\n",
    "interactions = [\n",
    "    {\n",
    "        'Variable 1': attribute_names[i],\n",
    "        'Variable 2': attribute_names[j],\n",
    "        'SHAP Inter Variation PRE-POST': variation,\n",
    "        'Interaction Type': interaction_type(attribute_names[i], attribute_names[j]),\n",
    "    }\n",
    "    for i, j, variation in zip(pair_rows, pair_cols, pair_variation)\n",
    "    if variation > tolerance\n",
    "]\n",
    "\n",
    "# Explicit columns keep the sort below working even when no pair exceeds the tolerance\n",
    "interaction_columns = ['Variable 1', 'Variable 2', 'SHAP Inter Variation PRE-POST', 'Interaction Type']\n",
    "interactions_df = pd.DataFrame(interactions, columns=interaction_columns)\n",
    "\n",
    "# Sort by PRE-POST variation, largest first\n",
    "sorted_interactions_df = interactions_df.sort_values(by='SHAP Inter Variation PRE-POST', ascending=False)\n",
    "\n",
    "# Export to Excel\n",
    "sorted_interactions_df.to_excel(f'./output/inter_variation_{method_name}.xlsx', index=False)\n",
    "\n",
    "print('Excel file has been created.')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}