Commit 1e2a171d authored by Joaquin Torres

regenerated plots and modified code to include CM as well

parent 6975fa04
@@ -17,6 +17,7 @@ from sklearn.tree import DecisionTreeClassifier
 from sklearn.metrics import RocCurveDisplay, roc_curve
 from sklearn.metrics import PrecisionRecallDisplay, precision_recall_curve
 import matplotlib.pyplot as plt
+from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
 # --------------------------------------------------------------------------------------------------------
 # Reading test data
@@ -250,7 +251,7 @@ if __name__ == "__main__":
         # Scores df
         scores_df = pd.DataFrame(index=models.keys(), columns=scorings.keys())
         # Create a figure for all models in this group-method
-        fig, axes = plt.subplots(len(models), 2, figsize=(8, 8 * len(models)))
+        fig, axes = plt.subplots(len(models), 3, figsize=(10, 8 * len(models)))
         if len(models) == 1: # Adjustment if there's only one model (axes indexing issue)
             axes = [axes]
         # Evaluate each model
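A note on the `axes = [axes]` guard this hunk keeps from the previous version: with a single model, `plt.subplots(1, 3)` squeezes the result into a 1-D array of Axes, so the `axes[model_idx][col]` indexing used further down would fail. A standalone sketch of the behaviour (illustrative only, not the repository's code):

import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 3)
print(axes.shape)   # (2, 3): axes[0][0] is an Axes object
plt.close(fig)

fig, axes = plt.subplots(1, 3)
print(axes.shape)   # (3,): axes[0] is already an Axes, so axes[0][0] would raise
axes = [axes]       # wrap the single row so axes[0][0], axes[0][1], axes[0][2] work again
plt.close(fig)

Passing `squeeze=False` to `plt.subplots` would achieve the same effect without the manual wrap.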
@@ -268,15 +269,21 @@ if __name__ == "__main__":
             # Calculate precision-recall curve
             precision, recall, _ = precision_recall_curve(y_test, y_score, pos_label=model.classes_[1])
             pr_display = PrecisionRecallDisplay(precision=precision, recall=recall).plot(ax=axes[model_idx][1])
+            # Get confusion matrix plot
+            y_pred = model.predict(X_test)
+            cm = confusion_matrix(y_test, y_pred)
+            ConfusionMatrixDisplay(cm).plot(ax=axes[model_idx][2])
+            # Give name to plots
             axes[model_idx][0].set_title(f'ROC Curve for {model_name}')
             axes[model_idx][1].set_title(f'PR Curve for {model_name}')
+            axes[model_idx][2].set_title(f'CM for {model_name}')
             # Evaluate at each of the scores of interest
             for score_name, scorer in scorings.items():
                 score_value = scorer(model, X_test, y_test)
                 scores_df.at[model_name, score_name] = score_value
         # Adjust layout and save/show figure
         plt.tight_layout()
-        plt.savefig(f'./test_results/roc_pr_curves/{group}_{method_names[j]}.svg', format='svg', dpi=500)
+        plt.savefig(f'./test_results/aux_plots/{group}_{method_names[j]}.svg', format='svg', dpi=500)
         plt.close(fig)
         # Store the DataFrame in the dictionary with a unique key for each sheet
         sheet_name = f"{group}_{method_names[j]}"
...
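For reference, a self-contained sketch of the per-model plotting loop as it stands after this commit: one row per model with ROC, PR and confusion-matrix panels. It is assembled from the hunks above, but the data, models (`make_classification`, `LogisticRegression`, `DecisionTreeClassifier`) and output path are placeholders, not the script's real test set, scorers or `./test_results/aux_plots/` files.

import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import (roc_curve, RocCurveDisplay,
                             precision_recall_curve, PrecisionRecallDisplay,
                             confusion_matrix, ConfusionMatrixDisplay)

# Toy data standing in for the script's test split
X, y = make_classification(n_samples=300, random_state=0)
X_train, X_test, y_train, y_test = X[:200], X[200:], y[:200], y[200:]
models = {"LR": LogisticRegression(max_iter=1000),
          "DT": DecisionTreeClassifier(random_state=0)}

# One row per model: ROC | PR | confusion matrix
fig, axes = plt.subplots(len(models), 3, figsize=(10, 8 * len(models)))
if len(models) == 1:  # keep two-level indexing when there is a single row
    axes = [axes]

for model_idx, (model_name, model) in enumerate(models.items()):
    model.fit(X_train, y_train)
    y_score = model.predict_proba(X_test)[:, 1]
    # ROC panel
    fpr, tpr, _ = roc_curve(y_test, y_score, pos_label=model.classes_[1])
    RocCurveDisplay(fpr=fpr, tpr=tpr).plot(ax=axes[model_idx][0])
    # PR panel
    precision, recall, _ = precision_recall_curve(y_test, y_score, pos_label=model.classes_[1])
    PrecisionRecallDisplay(precision=precision, recall=recall).plot(ax=axes[model_idx][1])
    # Confusion-matrix panel (the part this commit adds)
    cm = confusion_matrix(y_test, model.predict(X_test))
    ConfusionMatrixDisplay(cm).plot(ax=axes[model_idx][2])
    axes[model_idx][0].set_title(f'ROC Curve for {model_name}')
    axes[model_idx][1].set_title(f'PR Curve for {model_name}')
    axes[model_idx][2].set_title(f'CM for {model_name}')

plt.tight_layout()
plt.savefig('roc_pr_cm_example.svg', format='svg', dpi=500)  # placeholder path
plt.close(fig)

The confusion matrix is built from the hard labels returned by `model.predict`, while the ROC and PR panels use the continuous `y_score`, so the three panels give complementary views of the same test predictions.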