From aefbee6227c71eb12232d0197d66f89195922ef2 Mon Sep 17 00:00:00 2001 From: Joaquin Torres Bravo Date: Mon, 8 Jul 2024 12:52:25 +0200 Subject: [PATCH] Note about modifying the SHAP lib --- explainability/shap_plots.ipynb | 46 +++++++++++++++++++++++++++------ 1 file changed, 38 insertions(+), 8 deletions(-) diff --git a/explainability/shap_plots.ipynb b/explainability/shap_plots.ipynb index 6d5602a..3522c8d 100644 --- a/explainability/shap_plots.ipynb +++ b/explainability/shap_plots.ipynb @@ -17,7 +17,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -38,9 +38,22 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "FileNotFoundError", + "evalue": "[Errno 2] No such file or directory: '../EDA/output/feature_names/all_features.npy'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[3], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Retrieve attribute names in order\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m attribute_names \u001b[38;5;241m=\u001b[39m attribute_names \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(\u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43m../EDA/output/feature_names/all_features.npy\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mallow_pickle\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m)\n\u001b[1;32m 4\u001b[0m \u001b[38;5;66;03m# Load test data\u001b[39;00m\n\u001b[1;32m 5\u001b[0m X_test_pre \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m../gen_train_data/output/pre/X_test_pre.npy\u001b[39m\u001b[38;5;124m'\u001b[39m, allow_pickle\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n", + "File \u001b[0;32m~/covid_analysis/.venv/lib/python3.9/site-packages/numpy/lib/npyio.py:427\u001b[0m, in \u001b[0;36mload\u001b[0;34m(file, mmap_mode, allow_pickle, fix_imports, encoding, max_header_size)\u001b[0m\n\u001b[1;32m 425\u001b[0m own_fid \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[1;32m 426\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 427\u001b[0m fid \u001b[38;5;241m=\u001b[39m stack\u001b[38;5;241m.\u001b[39menter_context(\u001b[38;5;28;43mopen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mos_fspath\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfile\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mrb\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 428\u001b[0m own_fid \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m 430\u001b[0m \u001b[38;5;66;03m# Code to distinguish from NumPy binary files and pickles.\u001b[39;00m\n", + "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: '../EDA/output/feature_names/all_features.npy'" + ] + } + ], "source": [ "# Retrieve attribute names in order\n", "attribute_names = attribute_names = list(np.load('../EDA/output/feature_names/all_features.npy', allow_pickle=True))\n", @@ -133,9 +146,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'plt' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[1], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m method_name \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mORIG_CW\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[0;32m----> 2\u001b[0m \u001b[43mplt\u001b[49m\u001b[38;5;241m.\u001b[39mfigure(figsize\u001b[38;5;241m=\u001b[39m(\u001b[38;5;241m35\u001b[39m, \u001b[38;5;241m75\u001b[39m))\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i, group \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m([\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mpre\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mpost\u001b[39m\u001b[38;5;124m'\u001b[39m]):\n\u001b[1;32m 4\u001b[0m X_test \u001b[38;5;241m=\u001b[39m data_dic[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mX_test_\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;241m+\u001b[39m group]\n", + "\u001b[0;31mNameError\u001b[0m: name 'plt' is not defined" + ] + } + ], "source": [ "method_name = 'ORIG_CW'\n", "plt.figure(figsize=(35, 75))\n", @@ -182,8 +207,13 @@ "**IMPORTANT NOTE**\n", "For the code to work as intended, the SHAP source code had to be modified:\n", "* **File**: shap/plots/_beeswarm.py\n", - "* **Line**: 591\n", - "* **Modification**: " + "* **Modification**: .venv/lib/python3.9/site-packages/shap/plots/_beeswarm.py \\\n", + "\n", + "sort_inds = np.argsort(-np.abs(shap_values.sum(1)).sum(0)) \\\n", + "**replaced by** \\\n", + "sort_inds = np.arange(39)\n", + "\n", + "The idea is to display the features in the original order instead of sorting them according to their absolute SHAP impact, as the library originally does\n" ] }, { @@ -476,7 +506,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.2" + "version": "3.9.5" } }, "nbformat": 4, -- 2.24.1