Commit 87b69656 authored by Joaquin Torres's avatar Joaquin Torres

Another try

parent b682c08b
......@@ -209,8 +209,9 @@ if __name__ == "__main__":
# Curve generation setup
mean_fpr = np.linspace(0, 1, 100)
tprs, aucs = [], []
mean_recall = np.linspace(0, 1, 100)
precisions, pr_aucs = [], []
recall_points = np.linspace(0, 1, 100)
all_precisions = []
pr_aucs = []
cmap = plt.get_cmap('tab10') # Colormap
# Initialize storage for scores for each fold
fold_scores = {metric_name: [] for metric_name in scorings.keys()}
......@@ -240,18 +241,14 @@ if __name__ == "__main__":
pr_display = PrecisionRecallDisplay.from_estimator(model, X_test_fold, y_test_fold,
name=f"PR fold {fold_idx}", alpha=0.6, lw=2,
ax=axes[model_idx][1], color=cmap(fold_idx % 10))
# Reverse the recall and precision arrays for interpolation
recall_for_interp = pr_display.recall[::-1]
precision_for_interp = pr_display.precision[::-1]
# Handle the edge case where recall_for_interp has duplicates, which can break np.interp
recall_for_interp, unique_indices = np.unique(recall_for_interp, return_index=True)
precision_for_interp = precision_for_interp[unique_indices]
# Interpolate precision
interp_precision = np.interp(mean_recall, recall_for_interp, precision_for_interp)
precisions.append(interp_precision)
precision, recall = pr_display.precision, pr_display.recall
pr_aucs.append(pr_display.average_precision)
axes[model_idx][1].plot(recall, precision, alpha=0.6, lw=2, label=f"PR fold {fold_idx} (AP = {pr_display.average_precision:.2f})", color=cmap(fold_idx % 10))
# Store the precision values for each recall point
all_precisions.append(np.interp(recall_points, recall[::-1], precision[::-1]))
# Plot diagonal line for random guessing in ROC curve
axes[model_idx][0].plot([0, 1], [0, 1], linestyle='--', lw=2, color='r', alpha=.8, label='Random guessing')
# Compute mean ROC curve
......@@ -264,9 +261,9 @@ if __name__ == "__main__":
axes[model_idx][0].legend(loc="lower right")
# Compute mean Precision-Recall curve
mean_precision = np.mean(precisions, axis=0)
mean_precision = np.mean(all_precisions, axis=0)
mean_pr_auc = np.mean(pr_aucs)
axes[model_idx][1].plot(mean_recall, mean_precision, color='b', lw=4, label=r'Mean PR (AUC = %0.2f)' % mean_pr_auc, alpha=.8)
axes[model_idx][1].plot(recall_points, mean_precision, color='b', lw=4, label=r'Mean PR (AP = %0.2f)' % mean_pr_auc, alpha=.8)
# Plot baseline precision (proportion of positive samples)
baseline = np.sum(y_train) / len(y_train)
axes[model_idx][1].plot([0, 1], [baseline, baseline], linestyle='--', lw=2, color='r', alpha=.8, label='Baseline')
......
This source diff could not be displayed because it is too large. You can view the blob instead.
Markdown is supported
Attach a file by drag &amp; drop or click to upload (0%)
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment