From f919e0668cda6ae148baf927b080eff0ab2fad10 Mon Sep 17 00:00:00 2001 From: Joaquin Torres Bravo Date: Fri, 7 Jun 2024 11:50:19 +0200 Subject: [PATCH] Identified problem with features --- explicability/shap_vals.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/explicability/shap_vals.py b/explicability/shap_vals.py index 3424f29..d498cde 100644 --- a/explicability/shap_vals.py +++ b/explicability/shap_vals.py @@ -143,7 +143,7 @@ if __name__ == "__main__": # Shap value generation # -------------------------------------------------------------------------------------------------------- for i, group in enumerate(['pre', 'post']): - # Get test dataset based on group + # Get test dataset based on group, add column names X_test = pd.DataFrame(data_dic['X_test_' + group], columns=attribute_names) y_test = data_dic['y_test_' + group] for j, method in enumerate(['', '', 'over_', 'under_']): @@ -157,14 +157,15 @@ if __name__ == "__main__": # -------------------------------------------------------------------------------------------------------- # Fit model with training data fitted_model = model.fit(X_train[:500], y_train[:500]) - # Check if we are dealing with a tree vs nn model + # # Check if we are dealing with a tree vs nn model if is_tree: - explainer = shap.TreeExplainer(fitted_model, X_test[:500]) - else: - explainer = shap.KernelExplainer(fitted_model.predict_proba, X_test[:500]) + explainer = shap.TreeExplainer(fitted_model) + # else: + # explainer = shap.KernelExplainer(fitted_model.predict_proba, X_test[:500]) # Compute shap values shap_vals = explainer.shap_values(X_test[:500], check_additivity=False) # Change to true for final results # --------------------------------------------------------------------------------------------------------- # Save results np.save(f"./output/shap_values/{group}_{method_names[j]}", shap_vals) + print(f'Shape of numpy array: {shap_vals.shape}') # -------------------------------------------------------------------------------------------------------- \ No newline at end of file -- 2.24.1