{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import shap\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.patches as mpatches"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Column labels for the feature matrices, in their stored order\n",
    "attribute_names = list(np.load('../gen_train_data/data/output/attributes.npy', allow_pickle=True))\n",
    "\n",
    "# Test-set arrays for the 'pre' and 'post' groups\n",
    "X_test_pre = np.load('../gen_train_data/data/output/pre/X_test_pre.npy', allow_pickle=True)\n",
    "y_test_pre = np.load('../gen_train_data/data/output/pre/y_test_pre.npy', allow_pickle=True)\n",
    "X_test_post = np.load('../gen_train_data/data/output/post/X_test_post.npy', allow_pickle=True)\n",
    "y_test_post = np.load('../gen_train_data/data/output/post/y_test_post.npy', allow_pickle=True)\n",
    "\n",
    "# Feature arrays become named-column DataFrames with pandas-inferred dtypes;\n",
    "# the label arrays are stored unchanged.\n",
    "data_dic = dict(\n",
    "    X_test_pre=pd.DataFrame(data=X_test_pre, columns=attribute_names).convert_dtypes(),\n",
    "    y_test_pre=y_test_pre,\n",
    "    X_test_post=pd.DataFrame(data=X_test_post, columns=attribute_names).convert_dtypes(),\n",
    "    y_test_post=y_test_post,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pipeline index -> pipeline label\n",
    "method_names = dict(enumerate([\"ORIG\", \"ORIG_CW\", \"OVER\", \"UNDER\"]))\n",
    "\n",
    "# Pipeline label -> name of the model paired with that pipeline\n",
    "model_choices = dict(ORIG=\"XGB\", ORIG_CW=\"RF\", OVER=\"XGB\", UNDER=\"XGB\")\n",
    "\n",
    "# Variable-name arrays stored under ../EDA\n",
    "soc_var_names = np.load('../EDA/soc_vars_names.npy', allow_pickle=True)\n",
    "ind_var_names = np.load('../EDA/ind_vars_names.npy', allow_pickle=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "SHAP Plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the 'pre' and 'post' SHAP summary plots for one pipeline as two stacked\n",
    "# subplots, colour the feature labels by category, and save the figure as SVG.\n",
    "method_name = 'OVER'\n",
    "model_name = model_choices[method_name]  # identical for both groups, so look it up once\n",
    "\n",
    "# Tick-label colour -> legend text; the same legend is attached to every subplot\n",
    "used_colors = {'purple': 'Social factor', 'green': 'Individual factor'}\n",
    "legend_handles = [mpatches.Patch(color=color, label=text) for color, text in used_colors.items()]\n",
    "\n",
    "# NOTE(review): shap.summary_plot defaults to plot_size='auto', which may resize the\n",
    "# current figure; pass plot_size=None if the 35x75 canvas must be kept - confirm.\n",
    "fig = plt.figure(figsize=(35, 75))\n",
    "for i, group in enumerate(['pre', 'post']):\n",
    "    X_test = data_dic['X_test_' + group]\n",
    "    y_test = data_dic['y_test_' + group]  # not used by the plot itself\n",
    "    shap_vals = np.load(f'./output/shap_values/{group}_{method_name}.npy')\n",
    "    ax = plt.subplot(2, 1, i + 1)\n",
    "    # show=False defers display so the axes can be customised below\n",
    "    shap.summary_plot(shap_vals, X_test, max_display=len(attribute_names), show=False)\n",
    "    plt.title(group.upper(), fontsize=12, fontweight='bold')\n",
    "    plt.xlabel('SHAP Value')\n",
    "    plt.xlim(-3, 5)\n",
    "\n",
    "    # Purple for names listed in soc_var_names, green for everything else\n",
    "    for label in ax.get_yticklabels():\n",
    "        label.set_fontsize(8)\n",
    "        label.set_color('purple' if label.get_text() in soc_var_names else 'green')\n",
    "\n",
    "    # Attach the legend once per subplot rather than once per tick label\n",
    "    ax.legend(handles=legend_handles, loc='lower right', fontsize=8)\n",
    "\n",
    "plt.suptitle(f'SHAP Summary Plots PRE vs POST - Pipeline: Oversampling - Model: {model_name}\\n\\n')\n",
    "# tight_layout() recomputes the subplot spacing, so no manual subplots_adjust is needed\n",
    "plt.tight_layout()\n",
    "# Save before show(): once show() has released the figure, savefig can write an empty canvas\n",
    "fig.savefig(f'./output/plots/shap_summary/{method_name}_{model_name}.svg', format='svg', dpi=1250)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Switch the selected pipeline to 'ORIG_CW' (paired with 'RF' in model_choices)\n",
    "method_name = 'ORIG_CW'\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}