diff --git a/explicability/shap_vals.py b/explicability/shap_vals.py index 8511020acfe54b88c37222063f741f010a1ab558..dc3cba557094985044338c14343fe7da4c02e57a 100644 --- a/explicability/shap_vals.py +++ b/explicability/shap_vals.py @@ -86,11 +86,13 @@ if __name__ == "__main__": "post_OVER": (None,None), "post_UNDER": (None,None), } + # # Retrieve attribute names in order + # df = pd.read_csv("..\gen_train_data\data\input\pre_dataset.csv") + # attribute_names = list(df.columns.values) # -------------------------------------------------------------------------------------------------------- # Shap value generation # -------------------------------------------------------------------------------------------------------- - shap_values = {} # Mapping group-method -> shap values for i, group in enumerate(['pre', 'post']): # Get test dataset based on group X_test = data_dic['X_test_' + group] @@ -105,8 +107,14 @@ if __name__ == "__main__": is_tree = model_info[0] model = model_info[1] # Fit model with training data - fitted_model = model.fit(X_train, y_train) # [:500]? + fitted_model = model.fit(X_train[:500], y_train[:500]) # Check if we are dealing with a tree vs nn model if is_tree: - explainer = shap.TreeExplainer(fitted_model, X_test) # [:500]? + explainer = shap.TreeExplainer(fitted_model, X_test[:500]) + else: + explainer = shap.KernelExplainer(fitted_model.predict, X_test[:500]) + # Compute shap values + shap_vals = explainer.shap_values(X_test[:500], check_additivity=False) # Change to true for final results + # Save results + np.save(f"shap_values/{group}_{method_names[j]}", shap_vals) # -------------------------------------------------------------------------------------------------------- \ No newline at end of file