From 7b58b74ce7f1db9ce19277f582ef7b6057863441 Mon Sep 17 00:00:00 2001 From: Joaquin Torres Bravo Date: Thu, 6 Jun 2024 16:40:17 +0200 Subject: [PATCH] Progress made on shap (still pending to see predict_proba and X_train vs test) --- explicability/shap_vals.py | 46 +++++++++++++++++--------------------- 1 file changed, 21 insertions(+), 25 deletions(-) diff --git a/explicability/shap_vals.py b/explicability/shap_vals.py index e5b6e89..3424f29 100644 --- a/explicability/shap_vals.py +++ b/explicability/shap_vals.py @@ -110,8 +110,10 @@ def get_chosen_model(group_str, method_str, model_name): # Initialize the model with the parameters chosen_model = model_class(**parameters) + # Return if it is a tree model, for SHAP + is_tree = model_name not in ['LR', 'SVM', 'MLP'] - return chosen_model + return chosen_model, is_tree # -------------------------------------------------------------------------------------------------------- if __name__ == "__main__": @@ -133,42 +135,36 @@ if __name__ == "__main__": "OVER": "XGB", "UNDER": "XGB" } - # # Retrieve attribute names in order - # df = pd.read_csv("..\gen_train_data\data\input\pre_dataset.csv") - # attribute_names = list(df.columns.values) + # Retrieve attribute names in order + df = pd.read_csv("../gen_train_data/data/input/pre_dataset.csv") + attribute_names = list(df.columns.values) # -------------------------------------------------------------------------------------------------------- # Shap value generation # -------------------------------------------------------------------------------------------------------- for i, group in enumerate(['pre', 'post']): # Get test dataset based on group - X_test = data_dic['X_test_' + group] + X_test = pd.DataFrame(data_dic['X_test_' + group], columns=attribute_names) y_test = data_dic['y_test_' + group] for j, method in enumerate(['', '', 'over_', 'under_']): print(f"{group}-{method_names[j]}") # Get train dataset based on group and method - X_train = data_dic['X_train_' + method + group] + X_train = pd.DataFrame(data_dic['X_train_' + method + group], columns=attribute_names) y_train = data_dic['y_train_' + method + group] method_name = method_names[j] # Get chosen tuned model for this group and method context - model = get_chosen_model(group_str=group, method_str=method_name, model_name=model_choices[method_name]) - print(f'Name: {model_choices[method_name]}') - print(model.get_params()) - # # -------------------------------------------------------------------------------------------------------- - # # Retrieve best model for this group-method context - # model_info = models[group + '_' + method_names[j]] - # is_tree = model_info[0] - # model = model_info[1] - # # Fit model with training data - # fitted_model = model.fit(X_train[:500], y_train[:500]) - # # Check if we are dealing with a tree vs nn model - # if is_tree: - # explainer = shap.TreeExplainer(fitted_model, X_test[:500]) - # else: - # explainer = shap.KernelExplainer(fitted_model.predict, X_test[:500]) - # # Compute shap values - # shap_vals = explainer.shap_values(X_test[:500], check_additivity=False) # Change to true for final results - # # --------------------------------------------------------------------------------------------------------- + model, is_tree = get_chosen_model(group_str=group, method_str=method_name, model_name=model_choices[method_name]) + # -------------------------------------------------------------------------------------------------------- + # Fit model with training data + fitted_model = model.fit(X_train[:500], y_train[:500]) + # Check if we are dealing with a tree vs nn model + if is_tree: + explainer = shap.TreeExplainer(fitted_model, X_test[:500]) + else: + explainer = shap.KernelExplainer(fitted_model.predict_proba, X_test[:500]) + # Compute shap values + shap_vals = explainer.shap_values(X_test[:500], check_additivity=False) # Change to true for final results + # --------------------------------------------------------------------------------------------------------- # Save results - # np.save(f"shap_values/{group}_{method_names[j]}", shap_vals) + np.save(f"./output/shap_values/{group}_{method_names[j]}", shap_vals) # -------------------------------------------------------------------------------------------------------- \ No newline at end of file -- 2.24.1