diff --git a/model_selection/cv_metric_gen.py b/model_selection/cv_metric_gen.py
index e8be09959b6593e9049e8abdae0a99780d45f061..7546601a3558c6dd78a3dbf83d82d4e11987946c 100644
--- a/model_selection/cv_metric_gen.py
+++ b/model_selection/cv_metric_gen.py
@@ -209,9 +209,8 @@ if __name__ == "__main__":
     # Curve generation setup
     mean_fpr = np.linspace(0, 1, 100)
     tprs, aucs = [], []
-    recall_points = np.linspace(0, 1, 100)
-    all_precisions = []
-    pr_aucs = []
+    mean_recall = np.linspace(0, 1, 100)
+    precisions, pr_aucs = [], []
     cmap = plt.get_cmap('tab10')  # Colormap
     # Initialize storage for scores for each fold
     fold_scores = {metric_name: [] for metric_name in scorings.keys()}
@@ -241,14 +240,18 @@ if __name__ == "__main__":
             pr_display = PrecisionRecallDisplay.from_estimator(model, X_test_fold, y_test_fold, name=f"PR fold {fold_idx}", alpha=0.6, lw=2, ax=axes[model_idx][1], color=cmap(fold_idx % 10))
+            # Reverse the recall and precision arrays for interpolation
+            recall_for_interp = pr_display.recall[::-1]
+            precision_for_interp = pr_display.precision[::-1]
-            precision, recall = pr_display.precision, pr_display.recall
-            pr_aucs.append(pr_display.average_precision)
-            axes[model_idx][1].plot(recall, precision, alpha=0.6, lw=2, label=f"PR fold {fold_idx} (AP = {pr_display.average_precision:.2f})", color=cmap(fold_idx % 10))
+            # Handle the edge case where recall_for_interp has duplicates, which can break np.interp
+            recall_for_interp, unique_indices = np.unique(recall_for_interp, return_index=True)
+            precision_for_interp = precision_for_interp[unique_indices]
-            # Store the precision values for each recall point
-            all_precisions.append(np.interp(recall_points, recall[::-1], precision[::-1]))
-
+            # Interpolate precision
+            interp_precision = np.interp(mean_recall, recall_for_interp, precision_for_interp)
+            precisions.append(interp_precision)
+            pr_aucs.append(pr_display.average_precision)
         # Plot diagonal line for random guessing in ROC curve
         axes[model_idx][0].plot([0, 1], [0, 1], linestyle='--', lw=2, color='r', alpha=.8, label='Random guessing')
         # Compute mean ROC curve
@@ -259,11 +262,11 @@ if __name__ == "__main__":
         # Set ROC plot limits and title
         axes[model_idx][0].set(xlim=[-0.05, 1.05], ylim=[-0.05, 1.05], title=f"ROC Curve - {model_name} ({group}-{method_names[j]})")
         axes[model_idx][0].legend(loc="lower right")
-
+
         # Compute mean Precision-Recall curve
-        mean_precision = np.mean(all_precisions, axis=0)
+        mean_precision = np.mean(precisions, axis=0)
         mean_pr_auc = np.mean(pr_aucs)
-        axes[model_idx][1].plot(recall_points, mean_precision, color='b', lw=4, label=r'Mean PR (AP = %0.2f)' % mean_pr_auc, alpha=.8)
+        axes[model_idx][1].plot(mean_recall, mean_precision, color='b', lw=4, label=r'Mean PR (AP = %0.2f)' % mean_pr_auc, alpha=.8)
         # Plot baseline precision (proportion of positive samples)
         baseline = np.sum(y_train) / len(y_train)
         axes[model_idx][1].plot([0, 1], [baseline, baseline], linestyle='--', lw=2, color='r', alpha=.8, label='Baseline')