Commit 9fa990e0 authored by Joaquin Torres's avatar Joaquin Torres

ROC curves looking good

parent f8e93c2e
...@@ -216,25 +216,23 @@ if __name__ == "__main__": ...@@ -216,25 +216,23 @@ if __name__ == "__main__":
# Append the interpolated TPR and AUC for this fold # Append the interpolated TPR and AUC for this fold
tprs.append(interp_tpr) tprs.append(interp_tpr)
aucs.append(roc_display.roc_auc) aucs.append(roc_display.roc_auc)
# Plot the diagonal line representing random guessing # Plot the diagonal line representing random guessing
axes[model_idx].plot([0, 1], [0, 1], linestyle='--', lw=2, color='r', alpha=.8) axes[model_idx].plot([0, 1], [0, 1], linestyle='--', lw=2, color='r', alpha=.8, label='Random guessing')
# Compute the mean of the TPRs # Compute the mean of the TPRs
mean_tpr = np.mean(tprs, axis=0) mean_tpr = np.mean(tprs, axis=0)
mean_tpr[-1] = 1.0 mean_tpr[-1] = 1.0
mean_auc = auc(mean_fpr, mean_tpr) # Calculate the mean AUC mean_auc = auc(mean_fpr, mean_tpr) # Calculate the mean AUC
# Plot the mean ROC curve with a thicker line and distinct color # Plot the mean ROC curve with a thicker line and distinct color
axes[model_idx].plot(mean_fpr, mean_tpr, color='b', lw=4, axes[model_idx].plot(mean_fpr, mean_tpr, color='b', lw=4,
label=r'Mean ROC (AUC = %0.2f)' % mean_auc, alpha=.8) label=r'Mean ROC (AUC = %0.2f)' % mean_auc, alpha=.8)
# Set plot limits and title # Set plot limits and title
axes[model_idx].set(xlim=[-0.05, 1.05], ylim=[-0.05, 1.05], axes[model_idx].set(xlim=[-0.05, 1.05], ylim=[-0.05, 1.05],
title=f"ROC Curve - {model_name} ({group}-{method_names[j]})") title=f"ROC Curve - {model_name} ({group}-{method_names[j]})")
axes[model_idx].legend(loc="lower right") axes[model_idx].legend(loc="lower right")
# Store the DataFrame in the dictionary with a unique key for each sheet # Store the DataFrame in the dictionary with a unique key for each sheet
sheet_name = f"{group}_{method_names[j]}" sheet_name = f"{group}_{method_names[j]}"
scores_sheets[sheet_name] = scores_df scores_sheets[sheet_name] = scores_df
# Adjust layout and save/show figure # Adjust layout and save figure
plt.tight_layout() plt.tight_layout()
plt.savefig(f'./output_cv_metrics/curves/{group}_{method_names[j]}.svg', format='svg', dpi=500) plt.savefig(f'./output_cv_metrics/curves/{group}_{method_names[j]}.svg', format='svg', dpi=500)
plt.close(fig) plt.close(fig)
......
This source diff could not be displayed because it is too large. You can view the blob instead.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment