Commit e45e7474 authored by Joaquin Torres's avatar Joaquin Torres

testing script for DT and RF

parent ebe6d11f
...@@ -14,6 +14,9 @@ from sklearn.neural_network import MLPClassifier ...@@ -14,6 +14,9 @@ from sklearn.neural_network import MLPClassifier
from sklearn.svm import SVC from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import RocCurveDisplay, roc_curve
from sklearn.metrics import PrecisionRecallDisplay, precision_recall_curve
import matplotlib.pyplot as plt
# -------------------------------------------------------------------------------------------------------- # --------------------------------------------------------------------------------------------------------
# Reading test data # Reading test data
...@@ -246,16 +249,36 @@ if __name__ == "__main__": ...@@ -246,16 +249,36 @@ if __name__ == "__main__":
models = get_tuned_models(group_id=i, method_id=j) models = get_tuned_models(group_id=i, method_id=j)
# Scores df # Scores df
scores_df = pd.DataFrame(index=models.keys(), columns=scorings.keys()) scores_df = pd.DataFrame(index=models.keys(), columns=scorings.keys())
# Create a figure for all models in this group-method
fig, axes = plt.subplots(len(models), 2, figsize=(12, 8 * len(models)))
if len(models) == 1: # Adjustment if there's only one model (axes indexing issue)
axes = [axes]
# Evaluate each model # Evaluate each model
for model_name, model in models.items(): for model_idx, (model_name, model) in enumerate(models.items()):
# ----------- TEMPORAL ------------- # ----------- TEMPORAL -------------
if model_name == "DT": if model_name == "DT" or model_name == "RF":
# Train the model (it was just initialized above) # Train the model (it was just initialized above)
model.fit(X_train, y_train) model.fit(X_train, y_train)
if hasattr(model, "decision_function"):
y_score = model.decision_function(X_test)
else:
y_score = model.predict_proba(X_test)[:, 1] # Use probability of positive class
# Calculate ROC curve and ROC area for each class
fpr, tpr, _ = roc_curve(y_test, y_score, pos_label=model.classes_[1])
roc_display = RocCurveDisplay(fpr=fpr, tpr=tpr).plot(ax=axes[model_idx][0])
# Calculate precision-recall curve
precision, recall, _ = precision_recall_curve(y_test, y_score, pos_label=model.classes_[1])
pr_display = PrecisionRecallDisplay(precision=precision, recall=recall).plot(ax=axes[model_idx][1])
axes[model_idx][0].set_title(f'ROC Curve for {model_name}')
axes[model_idx][1].set_title(f'PR Curve for {model_name}')
# Evaluate at each of the scores of interest # Evaluate at each of the scores of interest
for score_name, scorer in scorings.items(): for score_name, scorer in scorings.items():
score_value = scorer(model, X_test, y_test) score_value = scorer(model, X_test, y_test)
scores_df.at[model_name, score_name] = score_value scores_df.at[model_name, score_name] = score_value
# Adjust layout and save/show figure
plt.tight_layout()
plt.savefig(f'./test_results/roc_pr_curves/{group}_{method}.svg', format='svg', dpi=500)
plt.close(fig)
# Store the DataFrame in the dictionary with a unique key for each sheet # Store the DataFrame in the dictionary with a unique key for each sheet
sheet_name = f"{group}_{method_names[j]}" sheet_name = f"{group}_{method_names[j]}"
scores_sheets[sheet_name] = scores_df scores_sheets[sheet_name] = scores_df
......
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment