Commit 7027a221 authored by Joaquin Torres's avatar Joaquin Torres

Testing with PRE-ORIG

parent 206e5f64
......@@ -206,43 +206,42 @@ if __name__ == "__main__":
# Metric generation for each model
for model_idx, (model_name, model) in enumerate(models.items()):
print(f"{group}-{method_names[j]}-{model_name}")
if model_name == 'DT':
# Curve generation setup
mean_fpr = np.linspace(0, 1, 100)
tprs, aucs = [], []
mean_recall = np.linspace(0, 1, 100)
precisions, pr_aucs = [], []
cmap = plt.get_cmap('tab10') # Colormap
# Initialize storage for scores for each fold
fold_scores = {metric_name: [] for metric_name in scorings.keys()}
# Manually loop through each fold in the cross-validation
for fold_idx, (train_idx, test_idx) in enumerate(cv.split(X_train, y_train)):
X_train_fold, X_test_fold = X_train[train_idx], X_train[test_idx]
y_train_fold, y_test_fold = y_train[train_idx], y_train[test_idx]
# Fit the model on the training data
model.fit(X_train_fold, y_train_fold)
# --------------------- SCORINGS ---------------------------
# Calculate and store the scores for each metric
for metric_name, scorer in scorings.items():
score = scorer(model, X_test_fold, y_test_fold)
fold_scores[metric_name].append(score)
# --------------------- END SCORINGS ---------------------------
# Curve generation setup
mean_fpr = np.linspace(0, 1, 100)
tprs, aucs = [], []
mean_recall = np.linspace(0, 1, 100)
precisions, pr_aucs = [], []
cmap = plt.get_cmap('tab10') # Colormap
# Initialize storage for scores for each fold
fold_scores = {metric_name: [] for metric_name in scorings.keys()}
# Manually loop through each fold in the cross-validation
for fold_idx, (train_idx, test_idx) in enumerate(cv.split(X_train, y_train)):
X_train_fold, X_test_fold = X_train[train_idx], X_train[test_idx]
y_train_fold, y_test_fold = y_train[train_idx], y_train[test_idx]
# Fit the model on the training data
model.fit(X_train_fold, y_train_fold)
# --------------------- SCORINGS ---------------------------
# Calculate and store the scores for each metric
for metric_name, scorer in scorings.items():
score = scorer(model, X_test_fold, y_test_fold)
fold_scores[metric_name].append(score)
# --------------------- END SCORINGS ---------------------------
# --------------------- CURVES ---------------------------
# Generate ROC curve for the fold
roc_display = RocCurveDisplay.from_estimator(model, X_test_fold, y_test_fold,
name=f"ROC fold {fold_idx}", alpha=0.6, lw=2,
ax=axes[model_idx][0], color=cmap(fold_idx % 10))
interp_tpr = np.interp(mean_fpr, roc_display.fpr, roc_display.tpr)
interp_tpr[0] = 0.0
tprs.append(interp_tpr)
aucs.append(roc_display.roc_auc)
# Generate Precision-Recall curve for the fold
pr_display = PrecisionRecallDisplay.from_estimator(model, X_test_fold, y_test_fold,
name=f"PR fold {fold_idx}", alpha=0.6, lw=2,
ax=axes[model_idx][1], color=cmap(fold_idx % 10))
interp_precision = np.interp(mean_recall, pr_display.recall[::-1], pr_display.precision[::-1])
precisions.append(interp_precision)
pr_aucs.append(pr_display.average_precision)
# Generate ROC curve for the fold
roc_display = RocCurveDisplay.from_estimator(model, X_test_fold, y_test_fold,
name=f"ROC fold {fold_idx}", alpha=0.6, lw=2,
ax=axes[model_idx][0], color=cmap(fold_idx % 10))
interp_tpr = np.interp(mean_fpr, roc_display.fpr, roc_display.tpr)
interp_tpr[0] = 0.0
tprs.append(interp_tpr)
aucs.append(roc_display.roc_auc)
# Generate Precision-Recall curve for the fold
pr_display = PrecisionRecallDisplay.from_estimator(model, X_test_fold, y_test_fold,
name=f"PR fold {fold_idx}", alpha=0.6, lw=2,
ax=axes[model_idx][1], color=cmap(fold_idx % 10))
interp_precision = np.interp(mean_recall, pr_display.recall[::-1], pr_display.precision[::-1])
precisions.append(interp_precision)
pr_aucs.append(pr_display.average_precision)
# Plot diagonal line for random guessing in ROC curve
axes[model_idx][0].plot([0, 1], [0, 1], linestyle='--', lw=2, color='r', alpha=.8, label='Random guessing')
# Compute mean ROC curve
......
This source diff could not be displayed because it is too large. You can view the blob instead.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment