Commit faf6e24f authored by Joaquin Torres's avatar Joaquin Torres

Test to check if mean PREC-REC issue fixed

parent 362330ae
...@@ -17,7 +17,7 @@ from sklearn.linear_model import LogisticRegression ...@@ -17,7 +17,7 @@ from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import StratifiedKFold from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import RocCurveDisplay, auc from sklearn.metrics import RocCurveDisplay, auc
from sklearn.metrics import PrecisionRecallDisplay from sklearn.metrics import PrecisionRecallDisplay, precision_recall_curve
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import ast # String to dictionary import ast # String to dictionary
# -------------------------------------------------------------------------------------------------------- # --------------------------------------------------------------------------------------------------------
...@@ -185,13 +185,15 @@ if __name__ == "__main__": ...@@ -185,13 +185,15 @@ if __name__ == "__main__":
} }
# Defining cross-validation protocol # Defining cross-validation protocol
cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=42) cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)
# Colormap
cmap = plt.get_cmap('tab10')
# -------------------------------------------------------------------------------------------------------- # --------------------------------------------------------------------------------------------------------
# Metric generation through cv for tuned models3 # Metric generation through cv for tuned models3
# -------------------------------------------------------------------------------------------------------- # --------------------------------------------------------------------------------------------------------
scores_sheets = {} # To store score dfs as sheets in the same excel file scores_sheets = {} # To store score dfs as sheets in the same excel file
for i, group in enumerate(['post']): # ['pre', 'post'] for i, group in enumerate(['pre', 'post']):
for j, method in enumerate(['']): # ['', '', 'over_', 'under_'] for j, method in enumerate(['', '', 'over_', 'under_']):
# Get train dataset based on group and method # Get train dataset based on group and method
X_train = data_dic['X_train_' + method + group] X_train = data_dic['X_train_' + method + group]
y_train = data_dic['y_train_' + method + group] y_train = data_dic['y_train_' + method + group]
...@@ -201,19 +203,16 @@ if __name__ == "__main__": ...@@ -201,19 +203,16 @@ if __name__ == "__main__":
scores_df = pd.DataFrame(columns=range(1,11), index=[f"{model_name}_{metric_name}" for model_name in models.keys() for metric_name in scorings.keys()]) scores_df = pd.DataFrame(columns=range(1,11), index=[f"{model_name}_{metric_name}" for model_name in models.keys() for metric_name in scorings.keys()])
# Create a figure with 2 subplots (roc and pr curves) for each model in this group-method # Create a figure with 2 subplots (roc and pr curves) for each model in this group-method
fig, axes = plt.subplots(len(models), 2, figsize=(10, 8 * len(models))) fig, axes = plt.subplots(len(models), 2, figsize=(10, 8 * len(models)))
if len(models) == 1: # Adjustment if there's only one model (axes indexing issue)
axes = [axes]
# Metric generation for each model # Metric generation for each model
for model_idx, (model_name, model) in enumerate(models.items()): for model_idx, (model_name, model) in enumerate(models.items()):
print(f"{group}-{method_names[j]}-{model_name}") print(f"{group}-{method_names[j]}-{model_name}")
# Curve generation setup
mean_fpr = np.linspace(0, 1, 100)
tprs, aucs = [], []
mean_recall = np.linspace(0, 1, 100)
precisions, pr_aucs = [], []
cmap = plt.get_cmap('tab10') # Colormap
# Initialize storage for scores for each fold # Initialize storage for scores for each fold
fold_scores = {metric_name: [] for metric_name in scorings.keys()} fold_scores = {metric_name: [] for metric_name in scorings.keys()}
# ROC setup
mean_fpr = np.linspace(0, 1, 100)
tprs, aucs = [], []
# PR setup
y_real, y_proba = [], []
# Manually loop through each fold in the cross-validation # Manually loop through each fold in the cross-validation
for fold_idx, (train_idx, test_idx) in enumerate(cv.split(X_train, y_train)): for fold_idx, (train_idx, test_idx) in enumerate(cv.split(X_train, y_train)):
X_train_fold, X_test_fold = X_train[train_idx], X_train[test_idx] X_train_fold, X_test_fold = X_train[train_idx], X_train[test_idx]
...@@ -225,9 +224,8 @@ if __name__ == "__main__": ...@@ -225,9 +224,8 @@ if __name__ == "__main__":
for metric_name, scorer in scorings.items(): for metric_name, scorer in scorings.items():
score = scorer(model, X_test_fold, y_test_fold) score = scorer(model, X_test_fold, y_test_fold)
fold_scores[metric_name].append(score) fold_scores[metric_name].append(score)
# --------------------- END SCORINGS ---------------------------
# --------------------- CURVES --------------------------- # --------------------- CURVES ---------------------------
# Generate ROC curve for the fold # ROC generation for current fold
roc_display = RocCurveDisplay.from_estimator(model, X_test_fold, y_test_fold, roc_display = RocCurveDisplay.from_estimator(model, X_test_fold, y_test_fold,
name=f"ROC fold {fold_idx}", alpha=0.6, lw=2, name=f"ROC fold {fold_idx}", alpha=0.6, lw=2,
ax=axes[model_idx][0], color=cmap(fold_idx % 10)) ax=axes[model_idx][0], color=cmap(fold_idx % 10))
...@@ -235,44 +233,42 @@ if __name__ == "__main__": ...@@ -235,44 +233,42 @@ if __name__ == "__main__":
interp_tpr[0] = 0.0 interp_tpr[0] = 0.0
tprs.append(interp_tpr) tprs.append(interp_tpr)
aucs.append(roc_display.roc_auc) aucs.append(roc_display.roc_auc)
# PR-recall generation for current fold
# Generate Precision-Recall curve for the fold if hasattr(model, "decision_function"):
pr_display = PrecisionRecallDisplay.from_estimator(model, X_test_fold, y_test_fold, y_score = model.decision_function(X_test_fold)
name=f"PR fold {fold_idx}", alpha=0.6, lw=2, else:
ax=axes[model_idx][1], color=cmap(fold_idx % 10)) y_score = model.predict_proba(X_test_fold)[:, 1]
# Reverse the recall and precision arrays for interpolation precision, recall, _ = precision_recall_curve(y_test_fold, y_score)
recall_for_interp = pr_display.recall[::-1] pr_auc = average_precision_score(y_test_fold, y_score)
precision_for_interp = pr_display.precision[::-1] axes[model_idx][1].plot(recall, precision, lw=2, alpha=0.3, label='PR fold %d (AUPRC = %0.2f)' % (fold_idx, pr_auc))
y_real.append(y_test_fold)
# Handle the edge case where recall_for_interp has duplicates, which can break np.interp y_proba.append(y_score)
recall_for_interp, unique_indices = np.unique(recall_for_interp, return_index=True) # Mean ROC Curve
precision_for_interp = precision_for_interp[unique_indices]
# Interpolate precision
interp_precision = np.interp(mean_recall, recall_for_interp, precision_for_interp)
precisions.append(interp_precision)
pr_aucs.append(pr_display.average_precision)
# Plot diagonal line for random guessing in ROC curve
axes[model_idx][0].plot([0, 1], [0, 1], linestyle='--', lw=2, color='r', alpha=.8, label='Random guessing')
# Compute mean ROC curve
mean_tpr = np.mean(tprs, axis=0) mean_tpr = np.mean(tprs, axis=0)
mean_tpr[-1] = 1.0 mean_tpr[-1] = 1.0
mean_auc = auc(mean_fpr, mean_tpr) mean_auc = auc(mean_fpr, mean_tpr)
axes[model_idx][0].plot(mean_fpr, mean_tpr, color='b', lw=4, label=r'Mean ROC (AUC = %0.2f)' % mean_auc, alpha=.8) axes[model_idx][0].plot(mean_fpr, mean_tpr, color='b', lw=4, label=r'Mean ROC (AUC = %0.2f)' % mean_auc, alpha=.8)
# Plot diagonal line for random guessing in ROC curve
axes[model_idx][0].plot([0, 1], [0, 1], linestyle='--', lw=2, color='r', alpha=.8, label='Random guessing')
# Set ROC plot limits and title # Set ROC plot limits and title
axes[model_idx][0].set(xlim=[-0.05, 1.05], ylim=[-0.05, 1.05], title=f"ROC Curve - {model_name} ({group}-{method_names[j]})") axes[model_idx][0].set(xlim=[-0.05, 1.05], ylim=[-0.05, 1.05], title=f"ROC Curve - {model_name} ({group}-{method_names[j]})")
axes[model_idx][0].legend(loc="lower right") axes[model_idx][0].legend(loc="lower right", fontsize='small')
# Mean PR Curve
# Compute mean Precision-Recall curve y_real = np.concatenate(y_real)
mean_precision = np.mean(precisions, axis=0) y_proba = np.concatenate(y_proba)
mean_pr_auc = np.mean(pr_aucs) precision, recall, _ = precision_recall_curve(y_real, y_proba)
axes[model_idx][1].plot(mean_recall, mean_precision, color='b', lw=4, label=r'Mean PR (AUC = %0.2f)' % mean_pr_auc, alpha=.8) axes[model_idx][1].plot(recall, precision, color='b', label=r'Mean PR (AUPRC = %0.2f)' % (average_precision_score(y_real, y_proba)),
lw=4, alpha=.8)
# Plot baseline precision (proportion of positive samples) # Plot baseline precision (proportion of positive samples)
baseline = np.sum(y_train) / len(y_train) baseline = np.sum(y_train) / len(y_train)
axes[model_idx][1].plot([0, 1], [baseline, baseline], linestyle='--', lw=2, color='r', alpha=.8, label='Baseline') axes[model_idx][1].plot([0, 1], [baseline, baseline], linestyle='--', lw=2, color='r', alpha=.8, label='Baseline')
# Set Precision-Recall plot limits and title # Set Precision-Recall plot limits and title
axes[model_idx][1].set(xlim=[-0.05, 1.05], ylim=[-0.05, 1.05], title=f"Precision-Recall Curve - {model_name} ({group}-{method_names[j]})") axes[model_idx][1].set(xlim=[-0.05, 1.05], ylim=[-0.05, 1.05], title=f"Precision-Recall Curve - {model_name} ({group}-{method_names[j]})")
axes[model_idx][1].legend(loc="lower right") axes[model_idx][1].legend(loc="lower left", fontsize='small')
axes[model_idx][1].set_aspect('equal') # Set the aspect ratio to be
# Add axis labels
axes[model_idx][1].set_xlabel('Recall')
axes[model_idx][1].set_ylabel('Precision')
# --------------------- END CURVES --------------------------- # --------------------- END CURVES ---------------------------
# Store the fold scores in the dataframe # Store the fold scores in the dataframe
for metric_name, scores in fold_scores.items(): for metric_name, scores in fold_scores.items():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment