From 620b8a59e277328788889159a2aab564585f9f4c Mon Sep 17 00:00:00 2001 From: joaquintb Date: Wed, 22 May 2024 11:19:13 +0200 Subject: [PATCH] Corrected hyperparam_tuning metric computation --- model_selection/hyperparam_tuning.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/model_selection/hyperparam_tuning.py b/model_selection/hyperparam_tuning.py index fc97f49..6d8af49 100644 --- a/model_selection/hyperparam_tuning.py +++ b/model_selection/hyperparam_tuning.py @@ -73,13 +73,13 @@ if __name__ == "__main__": # -------------------------------------------------------------------------------------------------------- # 1. No class weight models_simple = {"DT" : DecisionTreeClassifier(), - # "RF" : RandomForestClassifier(), - # "Bagging" : BaggingClassifier(), - # "AB" : AdaBoostClassifier(algorithm='SAMME'), - # "XGB": XGBClassifier(), - # "LR" : LogisticRegression(max_iter=1000), - # "SVM" : SVC(probability=True, max_iter=1000), - # "MLP" : MLPClassifier(max_iter=500) + "RF" : RandomForestClassifier(), + "Bagging" : BaggingClassifier(), + "AB" : AdaBoostClassifier(algorithm='SAMME'), + "XGB": XGBClassifier(), + "LR" : LogisticRegression(max_iter=1000), + "SVM" : SVC(probability=True, max_iter=1000), + "MLP" : MLPClassifier(max_iter=500) # "ElNet" : LogisticRegression(max_iter=1000, penalty='elasticnet') } @@ -141,8 +141,8 @@ if __name__ == "__main__": # -------------------------------------------------------------------------------------------------------- # Store each df as a sheet in an excel file sheets_dict = {} - for i, group in enumerate(['pre']): #['pre', 'post'] - for j, method in enumerate(['under_']): #['', '', 'over_', 'under_'] + for i, group in enumerate(['pre', 'post']): + for j, method in enumerate(['', '', 'over_', 'under_']): # Get dataset based on group and method X = data_dic['X_train_' + method + group] y = data_dic['y_train_' + method + group] -- 2.24.1