Commit 9ad950f3 authored by Joaquin Torres's avatar Joaquin Torres

Fixed cm bar size

parent faf6e24f
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -19,6 +19,8 @@ from sklearn.metrics import PrecisionRecallDisplay, precision_recall_curve ...@@ -19,6 +19,8 @@ from sklearn.metrics import PrecisionRecallDisplay, precision_recall_curve
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import ast # String to dictionary import ast # String to dictionary
import seaborn as sns
from mpl_toolkits.axes_grid1 import make_axes_locatable
# -------------------------------------------------------------------------------------------------------- # --------------------------------------------------------------------------------------------------------
# Reading data # Reading data
...@@ -211,8 +213,6 @@ if __name__ == "__main__": ...@@ -211,8 +213,6 @@ if __name__ == "__main__":
scores_df = pd.DataFrame(index=models.keys(), columns=scorings.keys()) scores_df = pd.DataFrame(index=models.keys(), columns=scorings.keys())
# Create a figure for all models in this group-method # Create a figure for all models in this group-method
fig, axes = plt.subplots(len(models), 3, figsize=(10, 8 * len(models))) fig, axes = plt.subplots(len(models), 3, figsize=(10, 8 * len(models)))
if len(models) == 1: # Adjustment if there's only one model (axes indexing issue)
axes = [axes]
# Evaluate each model with test dataset # Evaluate each model with test dataset
for model_idx, (model_name, model) in enumerate(models.items()): for model_idx, (model_name, model) in enumerate(models.items()):
print(f"{group}-{method_names[j]}-{model_name}") print(f"{group}-{method_names[j]}-{model_name}")
...@@ -260,8 +260,18 @@ if __name__ == "__main__": ...@@ -260,8 +260,18 @@ if __name__ == "__main__":
# Compute confusion matrix # Compute confusion matrix
cm = confusion_matrix(y_test, y_pred) cm = confusion_matrix(y_test, y_pred)
# Plot the confusion matrix # Plot the confusion matrix
ConfusionMatrixDisplay(cm).plot(ax=axes[model_idx][2]) cmp = ConfusionMatrixDisplay(cm)
# Deactivate default colorbar
cmp.plot(ax=axes[model_idx][2], colorbar=False, cmap=sns.color_palette("light:b", as_cmap=True))
# Adding custom colorbar using make_axes_locatable
divider = make_axes_locatable(axes[model_idx][2])
cax = divider.append_axes("right", size="5%", pad=0.05)
plt.colorbar(cmp.im_, cax=cax)
axes[model_idx][2].set_title(f'CM for {group}-{method}-{model_name}') axes[model_idx][2].set_title(f'CM for {group}-{method}-{model_name}')
axes[model_idx][2].set_xlabel('Predicted label')
axes[model_idx][2].set_ylabel('True label')
# ---------------------------------------------------------- # ----------------------------------------------------------
# Adjust layout and save/show figure # Adjust layout and save/show figure
plt.tight_layout() plt.tight_layout()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment