{ "cells": [
 { "cell_type": "markdown", "metadata": {}, "source": [ "### EDA" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "#### Libraries" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "import pandas as pd\n",
  "import seaborn as sns\n",
  "import matplotlib.pyplot as plt\n",
  "import numpy as np" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "#### Preparing Data" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "##### Reading and filtering" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "bd_all = pd.read_spss('17_abril.sav')\n",
  "\n",
  "# Work only with alcohol patients whose 'Situacion_tratamiento' is 'Abandono' or 'Alta terapéutica'\n",
  "is_alcohol = bd_all['Alcohol_DxCIE'] == 'Sí'\n",
  "is_abandono_or_alta = bd_all['Situacion_tratamiento'].isin(['Abandono', 'Alta terapéutica'])\n",
  "\n",
  "# .copy() makes bd an independent frame: the column assignments in the cells below then modify bd itself\n",
  "# instead of a slice of bd_all (which triggers SettingWithCopyWarning)\n",
  "bd = bd_all[is_alcohol & is_abandono_or_alta].copy()" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "##### Dealing with unknown values" ] },
 { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
  "['Live with families or friends' 'live alone' 'live in institutions' '9.0']\n",
  "['Live with families or friends' 'live alone' 'live in institutions'\n",
  " 'Unknown']\n" ] } ], "source": [
  "# 9.0 represents unknown according to Variables.docx -> replace it\n",
  "print(bd['Social_inclusion'].unique())\n",
  "bd['Social_inclusion'] = bd['Social_inclusion'].replace('9.0', 'Unknown')\n",
  "print(bd['Social_inclusion'].unique())" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# Number of patients with an unknown value.\n",
  "# Sum the boolean mask: len() of the mask would only return the total number of rows.\n",
  "print((bd['Social_inclusion'] == 'Unknown').sum())" ] },
 { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
  "['No alterations (first exposure at 11 or more years)'\n",
  " 'Alterations (first exposure before 11 years old)' '9']\n",
  "['No alterations (first exposure at 11 or more years)'\n",
  " 'Alterations (first exposure before 11 years old)' 'Unknown']\n" ] } ], "source": [
  "print(bd['Alterations_early_childhood_develop'].unique())\n",
  "bd['Alterations_early_childhood_develop'] = bd['Alterations_early_childhood_develop'].replace('9', 'Unknown')\n",
  "print(bd['Alterations_early_childhood_develop'].unique())" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# Number of patients with an unknown value (sum of the boolean mask, not its length)\n",
  "print((bd['Alterations_early_childhood_develop'] == 'Unknown').sum())" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "print(bd['Risk_stigma'].unique())\n",
  "# Risk_stigma is categorical, so rename the 99.0 category instead of calling Series.replace,\n",
  "# whose category-changing behaviour is deprecated by pandas.\n",
  "# Categories missing from the mapping are passed through, so re-running this cell is harmless.\n",
  "bd['Risk_stigma'] = bd['Risk_stigma'].cat.rename_categories({99.0: 'Unknown'})\n",
  "print(bd['Risk_stigma'].unique())" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# Number of patients with an unknown value (sum of the boolean mask, not its length)\n",
  "print((bd['Risk_stigma'] == 'Unknown').sum())" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "print(bd['NumHijos'].unique())\n",
  "bd['NumHijos'] = bd['NumHijos'].replace(99.0, 'Unknown')\n",
  "print(bd['NumHijos'].unique())" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# Count the 'Unknown' marker written by the previous cell (99.0 is gone after that replacement)\n",
  "# NOTE(review): NumHijos is later listed in no_redef_cols as an already-numeric column;\n",
  "# confirm that a string marker is really intended here.\n",
  "print((bd['NumHijos'] == 'Unknown').sum())" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "##### Defining sets of patients" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# Pre-pandemic (.copy() so that the 'Group' column added below is written to an independent frame)\n",
  "conj_pre = bd[bd['Pandemia_inicio_fin_tratamiento'] == 'Inicio y fin prepandemia'].copy()\n",
  "# Pre-pandemic abandono\n",
  "pre_abandono = conj_pre[conj_pre['Situacion_tratamiento'] == 'Abandono']\n",
  "# Pre-pandemic alta\n",
  "pre_alta = conj_pre[conj_pre['Situacion_tratamiento'] == 'Alta terapéutica']\n",
  "\n",
  "# Post-pandemic\n",
  "# Merging last two classes to balance sets\n",
  "post_periods = ['Inicio prepandemia y fin en pandemia', 'inicio y fin en pandemia']\n",
  "conj_post = bd[bd['Pandemia_inicio_fin_tratamiento'].isin(post_periods)].copy()\n",
  "# Post-pandemic abandono\n",
  "post_abandono = conj_post[conj_post['Situacion_tratamiento'] == 'Abandono']\n",
  "# Post-pandemic alta\n",
  "post_alta = conj_post[conj_post['Situacion_tratamiento'] == 'Alta terapéutica']\n",
  "\n",
  "# Concatenate the two data frames and add a new column to distinguish between them. Useful for plots\n",
  "conj_post['Group'] = 'Post'\n",
  "conj_pre['Group'] = 'Pre'\n",
  "combined_pre_post = pd.concat([conj_post, conj_pre])" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "#### First Steps" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "##### Inspecting the dataframes" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "def print_info(named_frames: dict) -> None:\n",
  "    \"\"\"Print the DataFrame.info() report of every frame in named_frames (label -> DataFrame).\"\"\"\n",
  "    for label, frame in named_frames.items():\n",
  "        print(label)\n",
  "        # info() prints its own report and returns None, so it is not wrapped in print()\n",
  "        frame.info()\n",
  "        print(\"-------------------------------\")\n",
  "\n",
  "\n",
  "print_info({\"PRE\": conj_pre, \"PRE-ABANDONO\": pre_abandono, \"PRE-ALTA\": pre_alta})\n",
  "\n",
  "print(\"\\n\\n\\n\")\n",
  "\n",
  "print_info({\"POST\": conj_post, \"POST-ABANDONO\": post_abandono, \"POST-ALTA\": post_alta})" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "##### Quantifying Null Values" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "print(f\"Total missing values Age: {bd['Age'].isnull().sum()}\")\n",
  "print(f\"Total missing values Años_consumo_droga: {bd['Años_consumo_droga'].isnull().sum()}\")\n",
  "print(f\"Total missing values Risk_stigma: {bd['Risk_stigma'].isnull().sum()}\")\n",
  "print(f\"Total missing values NumHijos: {bd['NumHijos'].isnull().sum()}\")\n",
  "\n",
  "print(\"\\tCONJUNTO PREPANDEMIA\")\n",
  "print(f\"\\t\\tMissing values Age: {conj_pre['Age'].isnull().sum()}\")\n",
  "print(f\"\\t\\tMissing values Años_consumo_droga: {conj_pre['Años_consumo_droga'].isnull().sum()}\")\n",
  "print(f\"\\t\\tMissing values Risk_stigma: {conj_pre['Risk_stigma'].isnull().sum()}\")\n",
  "print(f\"\\t\\tMissing values NumHijos: 
{conj_pre['NumHijos'].isnull().sum()}\")\n",
  "\n",
  "print(\"\\tCONJUNTO POSTPANDEMIA\")\n",
  "print(f\"\\t\\tMissing values Age: {conj_post['Age'].isnull().sum()}\")\n",
  "print(f\"\\t\\tMissing values Años_consumo_droga: {conj_post['Años_consumo_droga'].isnull().sum()}\")\n",
  "print(f\"\\t\\tMissing values Risk_stigma: {conj_post['Risk_stigma'].isnull().sum()}\")\n",
  "print(f\"\\t\\tMissing values NumHijos: {conj_post['NumHijos'].isnull().sum()}\")" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "#### Distribution of variables" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "##### Classifying variables into numerical and discrete/categorical " ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "disc_atts = ['Education', 'Social_protection', 'Job_insecurity', 'Housing',\n",
  "             'Alterations_early_childhood_develop', 'Social_inclusion',\n",
  "             'Risk_stigma', 'Sex', 'NumHijos', 'Smoking', 'Biological_vulnerability',\n",
  "             'Opiaceos_DxCIE', 'Cannabis_DXCIE', 'BZD_DxCIE', 'Cocaina_DxCIE',\n",
  "             'Alucinogenos_DXCIE', 'Tabaco_DXCIE', 'FrecuenciaConsumo30Dias',\n",
  "             'OtrosDx_Psiquiatrico', 'Tx_previos', 'Readmisiones_estudios', 'Nreadmision'\n",
  "             ]\n",
  "\n",
  "num_atts = ['Structural_conflic', 'Adherencia_tto_recalc', 'Age', 'Años_consumo_droga', 'Tiempo_tx']" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "##### Distribution of discrete attributes" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "###### Count plots" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "fig, axs = plt.subplots(len(disc_atts), 1, figsize=(15, 5*len(disc_atts)))\n",
  "plt.subplots_adjust(hspace=0.75, wspace=1.25)\n",
  "\n",
  "# One hue level per (Situacion_tratamiento, Group) pair.\n",
  "# Built once here because it is the same for every attribute (it is a row-wise apply over the whole frame).\n",
  "outcome_by_group = combined_pre_post[['Situacion_tratamiento', 'Group']].apply(tuple, axis=1)\n",
  "outcome_by_group_order = [('Abandono', 'Pre'), ('Alta terapéutica', 'Pre'),\n",
  "                          ('Abandono', 'Post'), ('Alta terapéutica', 'Post')]\n",
  "\n",
  "for i, disc_att in enumerate(disc_atts):\n",
  "    ax = sns.countplot(x=disc_att, data=combined_pre_post, hue=outcome_by_group,\n",
  "                       hue_order=outcome_by_group_order, ax=axs[i])\n",
  "    ax.set_title(disc_att, fontsize=16, fontweight='bold')\n",
  "    ax.get_legend().set_title(\"Groups\")\n",
  "\n",
  "    # Adding count annotations (only patches labelled '_nolegend_' are annotated, which skips legend entries)\n",
  "    for p in ax.patches:\n",
  "        if p.get_label() == '_nolegend_':\n",
  "            ax.annotate(format(p.get_height(), '.0f'),\n",
  "                        (p.get_x() + p.get_width() / 2., p.get_height()),\n",
  "                        ha='center', va='center',\n",
  "                        xytext=(0, 9),\n",
  "                        textcoords='offset points')\n",
  "\n",
  "# Adjust layout to prevent overlapping titles\n",
  "plt.tight_layout()\n",
  "\n",
  "# Save the figure in SVG format with DPI=600 in the \"./EDA_plots\" folder\n",
  "plt.savefig('./EDA_plots/countplots.svg', dpi=600, bbox_inches='tight')" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "###### Normalized count plots" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "def plot_count_perc_norm(i: int, group: int, disc_att: str, ax=None) -> None:\n",
  "    \"\"\"\n",
  "    Plot, for one discrete attribute, the percentage that each category represents\n",
  "    within each 'Situacion_tratamiento' subset ('Abandono' / 'Alta terapéutica').\n",
  "\n",
  "    i: row of the global `axs` grid to draw on (only used when `ax` is not given)\n",
  "    group: 1 (all), 2 (pre), 3 (post)\n",
  "    disc_att: name of the discrete column to plot\n",
  "    ax: matplotlib Axes to draw on; defaults to axs[i, group-2]\n",
  "\n",
  "    Raises ValueError if group is not 1, 2 or 3.\n",
  "    \"\"\"\n",
  "\n",
  "    # Define data to work with based on group\n",
  "    frames = {1: bd, 2: conj_pre, 3: conj_post}\n",
  "    if group not in frames:\n",
  "        raise ValueError(f\"group must be 1 (all), 2 (pre) or 3 (post); got {group!r}\")\n",
  "    df = frames[group]\n",
  "    if ax is None:\n",
  "        ax = axs[i, group-2]\n",
  "\n",
  "    # GOAL: find percentage of each possible category within the total of its situacion_tto subset\n",
  "    # Count occurrences of every (Situacion_tratamiento, category) pair;\n",
  "    # observed=True keeps only the combinations that actually appear in the data\n",
  "    grouped_counts = (df.groupby(['Situacion_tratamiento', disc_att], observed=True)\n",
  "                        .size()\n",
  "                        .reset_index(name='count'))\n",
  "    # Total per 'Situacion_tratamiento' group, aligned row by row with grouped_counts\n",
  "    subset_totals = grouped_counts.groupby('Situacion_tratamiento', observed=True)['count'].transform('sum')\n",
  "    grouped_counts['percentage'] = grouped_counts['count'] / subset_totals * 100\n",
  "\n",
  "    # Category order on the x axis: order of first appearance in the grouped table\n",
  "    col_order = grouped_counts[disc_att].drop_duplicates().tolist()\n",
  "\n",
  "    # Draw the percentages themselves, one bar per (category, Situacion_tratamiento).\n",
  "    # Each bar takes its height from its own row, so no positional matching of patches to rows is needed.\n",
  "    sns.barplot(data=grouped_counts, x=disc_att, y='percentage', hue='Situacion_tratamiento',\n",
  "                order=col_order, hue_order=['Abandono', 'Alta terapéutica'], errorbar=None, ax=ax)\n",
  "\n",
  "    # y-axis represents percentages out of the total count\n",
  "    ax.set_ylim(0, 100)\n",
  "\n",
  "    for patch in ax.patches:\n",
  "        height = patch.get_height()\n",
  "        # Skip going over the legend values and over bars without a finite height\n",
  "        if patch.get_label() == \"_nolegend_\" and np.isfinite(height):\n",
  "            ax.annotate(f'{height:.2f}%', (patch.get_x() + patch.get_width() / 2., height),\n",
  "                        ha='center', va='bottom', fontsize=6, color='black', xytext=(0, 5),\n",
  "                        textcoords='offset points')" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "fig, axs = plt.subplots(len(disc_atts), 2, figsize=(15, 7*len(disc_atts)))\n",
  "plt.subplots_adjust(hspace=0.75, wspace=1.5)\n",
  "\n",
  "# (group code, panel title) for each column of the grid.\n",
  "# To add an ALL column, create the grid with 3 columns and prepend (1, \"ALL\").\n",
  "panels = [(2, \"PRE\"), (3, \"POST\")]\n",
  "\n",
  "for i, disc_att in enumerate(disc_atts):\n",
  "    for col, (group, panel_title) in enumerate(panels):\n",
  "        panel_ax = axs[i, col]\n",
  "        plot_count_perc_norm(i, group, disc_att, ax=panel_ax)\n",
  "        panel_ax.set_title(\"\\n\" + panel_title)\n",
  "        panel_ax.set_xlabel(disc_att, fontweight='bold')\n",
  "        panel_ax.set_ylabel(\"% of total within its Sit_tto group\")\n",
  "        panel_ax.tick_params(axis='x', rotation=90)\n",
  "\n",
  "# Adjust layout to prevent overlapping titles\n",
  "plt.tight_layout()\n",
  "\n",
  "# Save the figure in SVG format with DPI=600 in the \"./EDA_plots\" folder\n",
  "plt.savefig('./EDA_plots/norm_countplots.svg', dpi=600, bbox_inches='tight')" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "##### Distribution of numeric attributes" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "###### Summary statistics" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# Last expression of the cell: rendered as a table by the notebook\n",
  "bd[num_atts].describe()" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "###### Boxplots" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "fig, axs = plt.subplots(len(num_atts), 1, figsize=(12, 5*len(num_atts)))\n",
  "plt.subplots_adjust(hspace=0.75, wspace=1.5)\n",
  "\n",
  "for i, num_att in enumerate(num_atts):\n",
  "    # Draw on the Axes created above instead of re-selecting it through plt.subplot\n",
  "    sns.boxplot(\n",
  "        data=combined_pre_post,\n",
  "        x=num_att,\n",
  "        y='Group',\n",
  "        hue='Situacion_tratamiento',\n",
  "        ax=axs[i],\n",
  "    )\n",
  "\n",
  "# Adjust layout to prevent overlapping titles\n",
  "plt.tight_layout()\n",
  "\n",
  "# Save the figure in SVG format with DPI=600 in the \"./EDA_plots\" folder\n",
  "plt.savefig('./EDA_plots/boxplots.svg', dpi=600, bbox_inches='tight')" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "###### Histograms" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "fig, axs = plt.subplots(len(num_atts), 3, figsize=(15, 6*len(num_atts)))\n",
  "plt.subplots_adjust(hspace=0.75, wspace=1.5)\n",
  "\n",
  "for i, num_att in enumerate(num_atts):\n",
  "\n",
  " # 1: All alcohol patients\n",
  " sns.histplot(data=bd,x=num_att,bins=15, hue='Situacion_tratamiento', stat='probability', common_norm=False, kde=True,\n",
  " line_kws={'lw': 5}, alpha = 0.4, ax=axs[i, 0])\n",
  " axs[i, 0].set_title(f\"\\nDistr. 
of {num_att} - ALL\")\n",
  "\n",
  " # 2: PRE\n",
  " sns.histplot(data=conj_pre,x=num_att,bins=15, hue='Situacion_tratamiento', stat='probability', common_norm=False, kde=True, \n",
  " line_kws={'lw': 5}, alpha = 0.4, ax=axs[i, 1])\n",
  " axs[i, 1].set_title(f\"\\nDistr. of {num_att} - PRE\")\n",
  "\n",
  " # Subplot 3: POST\n",
  " sns.histplot(data=conj_post,x=num_att,bins=15, hue='Situacion_tratamiento', stat='probability', common_norm=False, kde=True, \n",
  " line_kws={'lw': 5}, alpha = 0.4, ax=axs[i, 2])\n",
  " axs[i, 2].set_title(f\"\\nDistr. of {num_att} - POST\")\n",
  "\n",
  "# Adjust layout to prevent overlapping titles\n",
  "plt.tight_layout()\n",
  "\n",
  "# Save the figure in SVG format with DPI=600 in the \"./EDA_plots\" folder\n",
  "plt.savefig('./EDA_plots/histograms.svg', dpi=600, bbox_inches='tight')" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "#### Correlation Analysis" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "##### Turning binary variables into 0/1 values" ] },
 { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [], "source": [
  "# Every binary variable gets a 0/1 copy stored in a new '<column>_REDEF' column.\n",
  "# Values that are not a key of the mapping become NaN (Series.map behaviour).\n",
  "yes_no = {'No': 0, 'Sí': 1}\n",
  "\n",
  "binary_mappings = {\n",
  "    'Sex': {'Hombre': 0, 'Mujer': 1},\n",
  "    'Smoking': yes_no,\n",
  "    'Biological_vulnerability': yes_no,\n",
  "    # 'Droga_DxCIE'\n",
  "    'Opiaceos_DxCIE': yes_no,\n",
  "    'Cannabis_DXCIE': yes_no,\n",
  "    'BZD_DxCIE': yes_no,\n",
  "    'Cocaina_DxCIE': yes_no,\n",
  "    'Alucinogenos_DXCIE': yes_no,\n",
  "    'Tabaco_DXCIE': yes_no,\n",
  "    'OtrosDx_Psiquiatrico': yes_no,\n",
  "    'Tx_previos': yes_no,\n",
  "    'Situacion_tratamiento': {'Abandono': 0, 'Alta terapéutica': 1},\n",
  "}\n",
  "\n",
  "for column, mapping in binary_mappings.items():\n",
  "    bd[column + '_REDEF'] = bd[column].map(mapping)" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "##### Defining groups of variables" ] },
 { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [], "source": [
  "# Social factors\n",
  "social_vars = [\n",
  "    'Education', 'Social_protection', 'Job_insecurity', 'Housing',\n",
  "    'Alterations_early_childhood_develop', 'Social_inclusion', 'Risk_stigma', 'Structural_conflic',\n",
  "]\n",
  "\n",
  "# Individual factors\n",
  "ind_vars = [\n",
  "    'Age', 'Sex', 'NumHijos', 'Smoking', 'Biological_vulnerability',\n",
  "    'Opiaceos_DxCIE', 'Cannabis_DXCIE', 'BZD_DxCIE', 'Cocaina_DxCIE', 'Alucinogenos_DXCIE', 'Tabaco_DXCIE',\n",
  "    'FrecuenciaConsumo30Dias', 'Años_consumo_droga', 'OtrosDx_Psiquiatrico', 'Tx_previos',\n",
  "    'Adherencia_tto_recalc',\n",
  "]\n",
  "\n",
  "target_var = 'Situacion_tratamiento'\n",
  "\n",
  "# Include alcohol?" 
] },
 { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [], "source": [
  "# Columns that are already numeric and we don't need to redefine \n",
  "no_redef_cols = ['Structural_conflic', 'Age', 'NumHijos', 'Años_consumo_droga', 'Adherencia_tto_recalc']" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# res_vars = ['Tiempo_tx', 'Readmisiones_estudios', 'Periodos_COVID', 'Pandemia_inicio_fin_tratamiento', \n",
  "# 'Nreadmision', 'Readmisiones_PRECOVID', 'Readmisiones_COVID']" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "##### One-hot encode all categorical variables" ] },
 { "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [], "source": [
  "# Shorten name for 'Alterations_early_childhood_develop'\n",
  "alterations_mapping = {\n",
  "    'No alterations (first exposure at 11 or more years)': 'No alterations',\n",
  "    'Alterations (first exposure before 11 years old)': 'Alterations',\n",
  "    'Unknown': 'Unknown'\n",
  "}\n",
  "\n",
  "bd['Alterations_early_childhood_develop_REDEF'] = bd['Alterations_early_childhood_develop'].map(alterations_mapping)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# Original approach\n",
  "# NOTE(review): the next cell assigns one_hot_vars again, so on a top-to-bottom run that later list is the one\n",
  "# in effect when corr_cols is built; confirm which of the two definitions is intended.\n",
  "one_hot_vars = ['Education', 'Social_protection', 'Job_insecurity', 'Housing', 'Alterations_early_childhood_develop']\n",
  "\n",
  "# social_vars, ind_vars and target_var keep the values assigned under \"Defining groups of variables\"" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# Specify columns to one hot encode; empty list otherwise\n",
  "one_hot_vars = ['Droga_Ppal_REC', 'Sexo_x_Hijos', 'Education',\n",
  "                'Job_insecurity', 'Housing', 'Social_inclusion', 'FrecuenciaConsumo30Dias']\n",
  "\n",
  "one_hots_vars_prefix = {\n",
  "    'Droga_Ppal_REC': 'DrogP',\n",
  "    'Sexo_x_Hijos': 'SexHij',\n",
  "    'Education': 'Ed',\n",
  "    'Job_insecurity': 'JobIn',\n",
  "    'Housing': 'Hous',\n",
  "    'Social_inclusion': 'SocInc',\n",
  "    'FrecuenciaConsumo30Dias': 'Frec30',\n",
  "}\n",
  "\n",
  "# Original attribute -> names of its one-hot columns\n",
  "one_hot_cols_dic = {}\n",
  "encoded_frames = []\n",
  "\n",
  "for one_hot_var in one_hot_vars:\n",
  "    # Create one hot encoding version of attribute\n",
  "    encoded_var = pd.get_dummies(bd[one_hot_var], prefix=one_hots_vars_prefix[one_hot_var])\n",
  "    one_hot_cols_dic[one_hot_var] = encoded_var.columns.tolist()\n",
  "    encoded_frames.append(encoded_var)\n",
  "\n",
  "# Drop one-hot columns left in bd by an earlier run of this cell, then concatenate the new columns to the\n",
  "# main df in a single step, so re-running the cell never duplicates column names\n",
  "previous_cols = [col for cols in one_hot_cols_dic.values() for col in cols if col in bd.columns]\n",
  "bd = pd.concat([bd.drop(columns=previous_cols), *encoded_frames], axis=1)\n",
  "\n",
  "print(one_hot_cols_dic['FrecuenciaConsumo30Dias'])" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "###### Defining final version of columns of interest" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "def encoded_columns(variables: list) -> list:\n",
  "    \"\"\"\n",
  "    Return the names of the columns that represent `variables` in the correlation analysis.\n",
  "\n",
  "    For each variable, in order:\n",
  "    - listed in no_redef_cols: the variable itself\n",
  "    - listed in one_hot_vars: all of its one-hot columns (from one_hot_cols_dic)\n",
  "    - otherwise: its redefined version, '<variable>_REDEF'\n",
  "    \"\"\"\n",
  "    encoded = []\n",
  "    for var in variables:\n",
  "        if var in no_redef_cols:\n",
  "            encoded.append(var)\n",
  "        elif var in one_hot_vars:\n",
  "            encoded.extend(one_hot_cols_dic[var])\n",
  "        else:\n",
  "            encoded.append(var + '_REDEF')\n",
  "    return encoded\n",
  "\n",
  "\n",
  "soc_vars_enc = encoded_columns(social_vars)\n",
  "ind_vars_enc = encoded_columns(ind_vars)\n",
  "\n",
  "# Final version of columns we need to use for correlation analysis\n",
  "corr_cols = soc_vars_enc + ind_vars_enc" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "###### Update main dfs" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# Rebuild the cohort frames from bd so that they include the columns added to bd above\n",
  "\n",
  "# Pre-pandemic\n",
  "conj_pre = bd[bd['Pandemia_inicio_fin_tratamiento'] == 'Inicio y fin prepandemia']\n",
  "# Pre-pandemic abandono\n",
  "pre_abandono = conj_pre[conj_pre['Situacion_tratamiento'] == 'Abandono']\n",
  "# Pre-pandemic alta\n",
  "pre_alta = conj_pre[conj_pre['Situacion_tratamiento'] == 'Alta terapéutica']\n",
  "\n",
  "# Post-pandemic\n",
  "# Merging last two classes to balance sets\n",
  "post_periods = ['Inicio prepandemia y fin en pandemia', 'inicio y fin en pandemia']\n",
  "conj_post = bd[bd['Pandemia_inicio_fin_tratamiento'].isin(post_periods)]\n",
  "# Post-pandemic abandono\n",
  "post_abandono = conj_post[conj_post['Situacion_tratamiento'] == 'Abandono']\n",
  "# Post-pandemic alta\n",
  "post_alta = conj_post[conj_post['Situacion_tratamiento'] == 'Alta terapéutica']" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "##### Plotting a correlation heatmap" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "def plot_heatmap(sit_tto: int, group: int, ax=None) -> None:\n",
  "    \"\"\"\n",
  "    Draw the lower-triangle correlation heatmap of corr_cols for one set of patients.\n",
  "\n",
  "    sit_tto: 1 (include it as another var), 2 (only abandono), 3 (only alta)\n",
  "    group: 1 (all alcohol patients), 2 (pre), 3 (post)\n",
  "    ax: matplotlib Axes to draw on; defaults to the current pyplot Axes\n",
  "\n",
  "    Raises ValueError if sit_tto or group is not 1, 2 or 3.\n",
  "    \"\"\"\n",
  "    cohorts = {1: (\"ALL\", bd), 2: (\"PRE\", conj_pre), 3: (\"POST\", conj_post)}\n",
  "    outcomes = {2: (\"ABANDONO\", 'Abandono'), 3: (\"ALTA\", 'Alta terapéutica')}\n",
  "\n",
  "    if group not in cohorts:\n",
  "        raise ValueError(f\"group must be 1 (all), 2 (pre) or 3 (post); got {group!r}\")\n",
  "    if sit_tto != 1 and sit_tto not in outcomes:\n",
  "        raise ValueError(f\"sit_tto must be 1 (as variable), 2 (abandono) or 3 (alta); got {sit_tto!r}\")\n",
  "    if ax is None:\n",
  "        ax = plt.gca()\n",
  "\n",
  "    # Title plot and select data based on group and sit_tto\n",
  "    cohort_name, frame = cohorts[group]\n",
  "    plot_title = \"Correl Matrix - \" + cohort_name\n",
  "\n",
  "    if sit_tto == 1:\n",
  "        # Include target as another variable\n",
  "        cols = [target_var + '_REDEF'] + corr_cols\n",
  "    else:\n",
  "        cols = corr_cols\n",
  "        outcome_name, outcome_value = outcomes[sit_tto]\n",
  "        frame = frame[frame['Situacion_tratamiento'] == outcome_value]\n",
  "        plot_title += \" - \" + outcome_name\n",
  "\n",
  "    corr_matrix = frame[cols].corr()\n",
  "    # Take the tick labels from the matrix itself so they always match the plotted rows and columns\n",
  "    labels = corr_matrix.columns.tolist()\n",
  "\n",
  "    # Create a mask for the upper triangle\n",
  "    mask = np.triu(np.ones_like(corr_matrix, dtype=bool))\n",
  "\n",
  "    # Create heatmap correlation matrix\n",
  "    sns.heatmap(corr_matrix, mask=mask, xticklabels=labels, yticklabels=labels, cmap=\"coolwarm\",\n",
  "                vmin=-1, vmax=1, annot=True, fmt=\".2f\", annot_kws={\"size\": 4}, ax=ax)\n",
  "\n",
  "    # Group ind vs social vars by color on both axes\n",
  "    for tick_label in [*ax.get_xticklabels(), *ax.get_yticklabels()]:\n",
  "        if tick_label.get_text() in ind_vars_enc:\n",
  "            tick_label.set_color('green')\n",
  "        elif tick_label.get_text() in soc_vars_enc:\n",
  "            tick_label.set_color('purple')\n",
  "\n",
  "    # Add legend and place it in lower left\n",
  "    ax.legend(handles=[\n",
  "        plt.Line2D([0], [0], marker='o', color='w', label='Social Factors', markerfacecolor='purple', markersize=10),\n",
  "        plt.Line2D([0], [0], marker='o', color='w', label='Individual Factors', markerfacecolor='green', markersize=10)\n",
  "    ], bbox_to_anchor=(-0.1, -0.1), fontsize=20)\n",
  "\n",
  "    ax.set_title(\"\\n\\n\" + plot_title, fontdict={'fontsize': 30, 'fontweight': 'bold'})\n",
  "\n",
  "\n",
  "def plot_heatmap_grid(out_path: str) -> None:\n",
  "    \"\"\"\n",
  "    Draw the 3x3 grid of correlation heatmaps and save it to out_path.\n",
  "\n",
  "    Rows go through sit_tto 1-3 and columns through group 1-3 (see plot_heatmap).\n",
  "    \"\"\"\n",
  "    fig, axs = plt.subplots(3, 3, figsize=(50, 50))\n",
  "    plt.subplots_adjust(hspace=0.75, wspace=2)\n",
  "\n",
  "    # Go through possible values for 'Situacion_tratamiento' and 'Group'\n",
  "    for sit_tto in range(1, 4):\n",
  "        for group in range(1, 4):\n",
  "            plot_heatmap(sit_tto, group, ax=axs[sit_tto-1, group-1])\n",
  "\n",
  "    # Adjust layout to prevent overlapping titles\n",
  "    plt.tight_layout()\n",
  "\n",
  "    plt.savefig(out_path, dpi=550, bbox_inches='tight')" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "###### Original approach (all categorical mapped to integers)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# NOTE(review): this figure uses corr_cols as built in \"Defining final version of columns of interest\";\n",
  "# on a top-to-bottom run that is the same column set as the one-hot figure below.\n",
  "\n",
  "# Save the figure in SVG format in the \"./EDA_plots\" folder\n",
  "plot_heatmap_grid('./EDA_plots/heatmaps_original.svg')" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "###### One-hot encoding approach" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# Save the figure in SVG format in the \"./EDA_plots\" folder\n",
  "plot_heatmap_grid('./EDA_plots/heatmaps_one_hot.svg')" ] }
 ],
 "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.2" } },
 "nbformat": 4, "nbformat_minor": 2 }