Commit bb7d8177 authored by Joaquin Torres's avatar Joaquin Torres

waiting for details to launch, preparing for shap

parent de229b37
...@@ -152,9 +152,7 @@ if __name__ == "__main__": ...@@ -152,9 +152,7 @@ if __name__ == "__main__":
# Store each df as a sheet in an excel file # Store each df as a sheet in an excel file
sheets_dict = {} sheets_dict = {}
for i, group in enumerate(['pre', 'post']): for i, group in enumerate(['pre', 'post']):
print(group, end = ' ')
for j, method in enumerate(['', '', 'over_', 'under_']): for j, method in enumerate(['', '', 'over_', 'under_']):
print(method, end = ' ')
# Get dataset based on group and method # Get dataset based on group and method
X = data_dic['X_train_' + method + group] X = data_dic['X_train_' + method + group]
y = data_dic['y_train_' + method + group] y = data_dic['y_train_' + method + group]
...@@ -163,7 +161,7 @@ if __name__ == "__main__": ...@@ -163,7 +161,7 @@ if __name__ == "__main__":
# Save results: params and best score for each of the mdodels of this method and group # Save results: params and best score for each of the mdodels of this method and group
hyperparam_df = pd.DataFrame(index=list(models.keys()), columns=['Parameters','Score']) hyperparam_df = pd.DataFrame(index=list(models.keys()), columns=['Parameters','Score'])
for model_name, model in models.items(): for model_name, model in models.items():
print(model_name + "\n\n") print(f"{group}-{method}-{model_name} \n\n")
# Find optimal hyperparams for curr model # Find optimal hyperparams for curr model
params = hyperparameters[model_name] params = hyperparameters[model_name]
search = RandomizedSearchCV(model, param_distributions=params, cv=cv, n_jobs=1, scoring='precision') search = RandomizedSearchCV(model, param_distributions=params, cv=cv, n_jobs=1, scoring='precision')
......
# Libraries
# --------------------------------------------------------------------------------------------------------
import shap
import numpy as np
# --------------------------------------------------------------------------------------------------------
# Load test data
X_test_pre = np.load('../gen_train_data/data/output/pre/X_test_pre.npy', allow_pickle=True)
print(list(X_test_pre.columns.values))
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment