diff --git a/model_selection/cv_metric_gen.py b/model_selection/cv_metric_gen.py index fea2f369733369f93cb6aee3d6a178bec616eb74..69184608f008dc0c218bcb206f5457b582f89c8a 100644 --- a/model_selection/cv_metric_gen.py +++ b/model_selection/cv_metric_gen.py @@ -16,7 +16,7 @@ from sklearn.svm import SVC from sklearn.linear_model import LogisticRegression from sklearn.tree import DecisionTreeClassifier from sklearn.model_selection import StratifiedKFold, cross_validate -from sklearn.metrics import RocCurveDisplay, roc_curve, auc +from sklearn.metrics import RocCurveDisplay, auc from sklearn.metrics import PrecisionRecallDisplay, precision_recall_curve import matplotlib.pyplot as plt import ast # String to dictionary @@ -185,49 +185,64 @@ if __name__ == "__main__": # Scores df -> one column per cv split, one row for each model-metric scores_df = pd.DataFrame(columns=range(1,11), index=[f"{model_name}_{metric_name}" for model_name in models.keys() for metric_name in scorings.keys()]) # Create a figure for all models in this group-method - fig, axes = plt.subplots(len(models), 1, figsize=(10, 8 * len(models))) + fig, axes = plt.subplots(len(models), 2, figsize=(10, 8 * len(models))) if len(models) == 1: # Adjustment if there's only one model (axes indexing issue) axes = [axes] # Metric generation for each model for model_idx, (model_name, model) in enumerate(models.items()): - print(f"{group}-{method_names[j]}-{model_name}") - # Retrieve cv scores for our metrics of interest - scores = cross_validate(model, X_train, y_train, scoring=scorings, cv=cv, return_train_score=True, n_jobs=10) - # Save results of each fold - for metric_name in scorings.keys(): - scores_df.loc[model_name + f'_{metric_name}']=list(np.around(np.array(scores[f"test_{metric_name}"]),4)) - # ---------- Generate ROC curves ---------- - mean_fpr = np.linspace(0, 1, 100) - tprs, aucs = [], [] - cmap = plt.get_cmap('tab10') # Colormap - # Loop through each fold in the cross-validation (redoing cv for simplicity) - for fold_idx, (train, test) in enumerate(cv.split(X_train, y_train)): - # Fit the model on the training data - model.fit(X_train[train], y_train[train]) - # Use RocCurveDisplay to generate the ROC curve - roc_display = RocCurveDisplay.from_estimator(model, X_train[test], y_train[test], - name=f"ROC fold {fold_idx}", alpha=0.6, lw=2, - ax=axes[model_idx], color=cmap(fold_idx % 10)) - # Interpolate the true positive rates to get a smooth curve - interp_tpr = np.interp(mean_fpr, roc_display.fpr, roc_display.tpr) - interp_tpr[0] = 0.0 - # Append the interpolated TPR and AUC for this fold - tprs.append(interp_tpr) - aucs.append(roc_display.roc_auc) - # Plot the diagonal line representing random guessing - axes[model_idx].plot([0, 1], [0, 1], linestyle='--', lw=2, color='r', alpha=.8, label='Random guessing') - # Compute the mean of the TPRs - mean_tpr = np.mean(tprs, axis=0) - mean_tpr[-1] = 1.0 - mean_auc = auc(mean_fpr, mean_tpr) # Calculate the mean AUC - # Plot the mean ROC curve with a thicker line and distinct color - axes[model_idx].plot(mean_fpr, mean_tpr, color='b', lw=4, - label=r'Mean ROC (AUC = %0.2f)' % mean_auc, alpha=.8) - # Set plot limits and title - axes[model_idx].set(xlim=[-0.05, 1.05], ylim=[-0.05, 1.05], - title=f"ROC Curve - {model_name} ({group}-{method_names[j]})") - axes[model_idx].legend(loc="lower right") - # ---------- END ROC curves Generation ---------- + if model_name == 'XGB': + print(f"{group}-{method_names[j]}-{model_name}") + # Retrieve cv scores for our metrics of interest + scores = cross_validate(model, X_train, y_train, scoring=scorings, cv=cv, return_train_score=True, n_jobs=10) + # Save results of each fold + for metric_name in scorings.keys(): + scores_df.loc[model_name + f'_{metric_name}']=list(np.around(np.array(scores[f"test_{metric_name}"]),4)) + # ---------------------------------------- Generate curves ---------------------------------------- + mean_fpr = np.linspace(0, 1, 100) + tprs, aucs = [], [] + mean_recall = np.linspace(0, 1, 100) + precisions, pr_aucs = [], [] + cmap = plt.get_cmap('tab10') # Colormap + # Loop through each fold in the cross-validation + for fold_idx, (train, test) in enumerate(cv.split(X_train, y_train)): + # Fit the model on the training data + model.fit(X_train[train], y_train[train]) + # Generate ROC curve for the fold + roc_display = RocCurveDisplay.from_estimator(model, X_train[test], y_train[test], + name=f"ROC fold {fold_idx}", alpha=0.6, lw=2, + ax=axes[model_idx][0], color=cmap(fold_idx % 10)) + interp_tpr = np.interp(mean_fpr, roc_display.fpr, roc_display.tpr) + interp_tpr[0] = 0.0 + tprs.append(interp_tpr) + aucs.append(roc_display.roc_auc) + # Generate Precision-Recall curve for the fold + pr_display = PrecisionRecallDisplay.from_estimator(model, X_train[test], y_train[test], + name=f"PR fold {fold_idx}", alpha=0.6, lw=2, + ax=axes[model_idx][1], color=cmap(fold_idx % 10)) + interp_precision = np.interp(mean_recall, pr_display.recall[::-1], pr_display.precision[::-1]) + precisions.append(interp_precision) + pr_aucs.append(pr_display.average_precision) + # Plot diagonal line for random guessing in ROC curve + axes[model_idx][0].plot([0, 1], [0, 1], linestyle='--', lw=2, color='r', alpha=.8, label='Random guessing') + # Compute mean ROC curve + mean_tpr = np.mean(tprs, axis=0) + mean_tpr[-1] = 1.0 + mean_auc = auc(mean_fpr, mean_tpr) + axes[model_idx][0].plot(mean_fpr, mean_tpr, color='b', lw=4, label=r'Mean ROC (AUC = %0.2f)' % mean_auc, alpha=.8) + # Set ROC plot limits and title + axes[model_idx][0].set(xlim=[-0.05, 1.05], ylim=[-0.05, 1.05], title=f"ROC Curve - {model_name} ({group}-{method_names[j]})") + axes[model_idx][0].legend(loc="lower right") + # Compute mean Precision-Recall curve + mean_precision = np.mean(precisions, axis=0) + mean_pr_auc = np.mean(pr_aucs) + axes[model_idx][1].plot(mean_recall, mean_precision, color='b', lw=4, label=r'Mean PR (AUC = %0.2f)' % mean_pr_auc, alpha=.8) + # # Plot baseline precision (proportion of positive samples) + # baseline = np.sum(y_train) / len(y_train) + # axes[model_idx][1].plot([0, 1], [baseline, baseline], linestyle='--', lw=2, color='r', alpha=.8, label='Baseline') + # Set Precision-Recall plot limits and title + axes[model_idx][1].set(xlim=[-0.05, 1.05], ylim=[-0.05, 1.05], title=f"Precision-Recall Curve - {model_name} ({group}-{method_names[j]})") + axes[model_idx][1].legend(loc="lower right") + # ---------------------------------------- End Generate Curves ---------------------------------------- # Store the DataFrame in the dictionary with a unique key for each sheet sheet_name = f"{group}_{method_names[j]}" scores_sheets[sheet_name] = scores_df @@ -239,7 +254,4 @@ if __name__ == "__main__": with pd.ExcelWriter('./output_cv_metrics/metrics.xlsx') as writer: for sheet_name, data in scores_sheets.items(): data.to_excel(writer, sheet_name=sheet_name) - print("Successful cv metric generation for tuned models") - - - \ No newline at end of file + print("Successful cv metric generation for tuned models") \ No newline at end of file diff --git a/model_selection/output_cv_metrics/curves/pre_ORIG.svg b/model_selection/output_cv_metrics/curves/pre_ORIG.svg index 79ee48f93d83a76f544bea4a44c19c9f8a17141d..3c10346ebd18d5c2b0b139471663d8ea078a9baf 100644 --- a/model_selection/output_cv_metrics/curves/pre_ORIG.svg +++ b/model_selection/output_cv_metrics/curves/pre_ORIG.svg @@ -6,7 +6,7 @@ - 2024-05-23T12:10:22.372119 + 2024-05-23T12:46:28.660631 image/svg+xml @@ -30,10 +30,10 @@ z - @@ -41,17 +41,17 @@ z - - + - + - + - + - + - + - + - + - + - + - + - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - - - - - + + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + - + - + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - + - - + - - + - - - + + + - + + - - + @@ -1160,44 +6707,45 @@ z - - - - - - - - - - - - - + + + + + + + + + + + + + + - - + - - + - - - + + + - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - @@ -1402,20 +6937,20 @@ L 492.535 394.494062 - - + + - - + - - - + + + @@ -1436,20 +6971,20 @@ L 492.535 409.172187 - - + + - - + - - - + + + - - + + - - + - - - + + + @@ -1538,20 +7073,47 @@ L 492.535 438.528437 - - + + - - + - - - + + + + + + @@ -1572,20 +7134,20 @@ L 492.535 453.206562 - - + + - - + - - - + + + @@ -1606,32 +7168,20 @@ L 492.535 467.884687 - - + + - - + - - - - - - + + + @@ -1652,20 +7202,20 @@ z - - + + - - + - - - + + + @@ -1686,20 +7236,20 @@ L 492.535 497.240937 - - + + - - + - - - + + + @@ -1720,20 +7270,20 @@ L 492.535 511.919062 - - + + - - + - + - + - - + - - - + + + - - + + - - - + + - - - + + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - - + + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - + - - + - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - - + + - - - + + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + - + @@ -2203,91 +15170,91 @@ z - - - + + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + - + @@ -2295,121 +15262,121 @@ z - - + - - + - - + - - + - - - + + - - - + + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + - + @@ -2417,91 +15384,91 @@ z - - - + + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + - + @@ -2509,121 +15476,121 @@ z - - + - - + - - + - - + - - - + + - - - + + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + - + @@ -2631,91 +15598,91 @@ z - - - + + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + - + @@ -2723,121 +15690,121 @@ z - - + - - + - - + - - + - - - + + - - - + + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + - + @@ -2845,91 +15812,91 @@ z - - - + + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + - + @@ -2937,121 +15904,121 @@ z - - + - - + - - + - - + - - - + + - - - + + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + - + @@ -3059,91 +16026,91 @@ z - - - + + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + - + @@ -3151,119 +16118,119 @@ z - - + - - + - - + - - + - - - + + - - - + + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + @@ -3273,91 +16240,91 @@ z - - - + + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + - + - - + + - + - + - + @@ -3365,31 +16332,34 @@ z - - + - + - - + - - + - - + + + + + diff --git a/model_selection/output_cv_metrics/metrics.xlsx b/model_selection/output_cv_metrics/metrics.xlsx index 003029e62d816579d0620a65a15a76e07d107fef..8680af8a26424fb75abaec8df07c2d3d8c0bcb15 100644 Binary files a/model_selection/output_cv_metrics/metrics.xlsx and b/model_selection/output_cv_metrics/metrics.xlsx differ