diff --git a/model_selection/test_models.py b/model_selection/test_models.py
index 7e2f94e1771dc33e3fb50706016039cd984b4bca..120fd522aa75e2b15b5626765738186d8a0923b1 100644
--- a/model_selection/test_models.py
+++ b/model_selection/test_models.py
@@ -14,6 +14,9 @@ from sklearn.neural_network import MLPClassifier
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
+from sklearn.metrics import RocCurveDisplay, roc_curve
+from sklearn.metrics import PrecisionRecallDisplay, precision_recall_curve
+import matplotlib.pyplot as plt
# --------------------------------------------------------------------------------------------------------
# Reading test data
@@ -246,16 +249,36 @@ if __name__ == "__main__":
models = get_tuned_models(group_id=i, method_id=j)
# Scores df
scores_df = pd.DataFrame(index=models.keys(), columns=scorings.keys())
+ # Create a figure for all models in this group-method
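+ # One row per model: ROC curve in the left column, precision-recall curve in the right column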
+ fig, axes = plt.subplots(len(models), 2, figsize=(12, 8 * len(models)))
+ if len(models) == 1: # With a single model, subplots returns a 1-D array of Axes; wrap it so axes[model_idx][col] indexing still works
+ axes = [axes]
# Evaluate each model
- for model_name, model in models.items():
+ for model_idx, (model_name, model) in enumerate(models.items()):
# ----------- TEMPORAL -------------
- if model_name == "DT":
+ if model_name in ("DT", "RF"):
# Train the model (it was just initialized above)
- model.fit(X_train, y_train)
+ model.fit(X_train, y_train)
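+ # ROC and PR curves need a continuous score: use decision_function when the estimator provides one, otherwise the positive-class probability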
+ if hasattr(model, "decision_function"):
+ y_score = model.decision_function(X_test)
+ else:
+ y_score = model.predict_proba(X_test)[:, 1] # Use probability of positive class
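+ # model.classes_[1] is treated as the positive label, consistent with the positive-class scores computed above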
+ # Compute the ROC curve for the positive class
+ fpr, tpr, _ = roc_curve(y_test, y_score, pos_label=model.classes_[1])
+ roc_display = RocCurveDisplay(fpr=fpr, tpr=tpr).plot(ax=axes[model_idx][0])
+ # Calculate precision-recall curve
+ precision, recall, _ = precision_recall_curve(y_test, y_score, pos_label=model.classes_[1])
+ pr_display = PrecisionRecallDisplay(precision=precision, recall=recall).plot(ax=axes[model_idx][1])
+ axes[model_idx][0].set_title(f'ROC Curve for {model_name}')
+ axes[model_idx][1].set_title(f'PR Curve for {model_name}')
# Evaluate at each of the scores of interest
for score_name, scorer in scorings.items():
score_value = scorer(model, X_test, y_test)
scores_df.at[model_name, score_name] = score_value
+ # Adjust layout, save the figure as SVG, and close it to free memory
+ fig.tight_layout()
+ fig.savefig(f'./test_results/roc_pr_curves/{group}_{method}.svg', format='svg', dpi=500)
+ plt.close(fig)
# Store the DataFrame in the dictionary with a unique key for each sheet
sheet_name = f"{group}_{method_names[j]}"
scores_sheets[sheet_name] = scores_df
diff --git a/model_selection/test_results/roc_pr_curves/post_.svg b/model_selection/test_results/roc_pr_curves/post_.svg
new file mode 100644
index 0000000000000000000000000000000000000000..0c7cf3c398f31f84bb9d0e8aef7cbafd01db908f
--- /dev/null
+++ b/model_selection/test_results/roc_pr_curves/post_.svg
@@ -0,0 +1,3678 @@
diff --git a/model_selection/test_results/roc_pr_curves/post_over_.svg b/model_selection/test_results/roc_pr_curves/post_over_.svg
new file mode 100644
index 0000000000000000000000000000000000000000..c28586cde6cfaec7924053c7ac1e2dfd896baf3c
--- /dev/null
+++ b/model_selection/test_results/roc_pr_curves/post_over_.svg
@@ -0,0 +1,4539 @@
diff --git a/model_selection/test_results/roc_pr_curves/post_under_.svg b/model_selection/test_results/roc_pr_curves/post_under_.svg
new file mode 100644
index 0000000000000000000000000000000000000000..ca7801aa0673619900060182ed684790d99f4781
--- /dev/null
+++ b/model_selection/test_results/roc_pr_curves/post_under_.svg
@@ -0,0 +1,4568 @@
diff --git a/model_selection/test_results/roc_pr_curves/pre_.svg b/model_selection/test_results/roc_pr_curves/pre_.svg
new file mode 100644
index 0000000000000000000000000000000000000000..9bd255c02db78cbc65ce4cbd21209e2b74a104ae
--- /dev/null
+++ b/model_selection/test_results/roc_pr_curves/pre_.svg
@@ -0,0 +1,3768 @@
diff --git a/model_selection/test_results/roc_pr_curves/pre_over_.svg b/model_selection/test_results/roc_pr_curves/pre_over_.svg
new file mode 100644
index 0000000000000000000000000000000000000000..8a2895f0718bb35913810782fd7a56b2d9ab5a35
--- /dev/null
+++ b/model_selection/test_results/roc_pr_curves/pre_over_.svg
@@ -0,0 +1,4526 @@
diff --git a/model_selection/test_results/roc_pr_curves/pre_under_.svg b/model_selection/test_results/roc_pr_curves/pre_under_.svg
new file mode 100644
index 0000000000000000000000000000000000000000..1ca019563b6d3a850a2712b72f077e12d06f66d6
--- /dev/null
+++ b/model_selection/test_results/roc_pr_curves/pre_under_.svg
@@ -0,0 +1,4544 @@
diff --git a/model_selection/test_results/testing_tuned_models.xlsx b/model_selection/test_results/testing_tuned_models.xlsx
index 1c0daa5be47052ff8dde569780d264ef84052403..df38c76852ae83412fcacee1cf290c0c8e98f248 100644
Binary files a/model_selection/test_results/testing_tuned_models.xlsx and b/model_selection/test_results/testing_tuned_models.xlsx differ