diff --git a/explicability/shap_vals.py b/explicability/shap_vals.py index 3424f2952ebe200562eda20e5126193778815534..d498cdeceb0103c3001447ccec0dc624086e8ce8 100644 --- a/explicability/shap_vals.py +++ b/explicability/shap_vals.py @@ -143,7 +143,7 @@ if __name__ == "__main__": # Shap value generation # -------------------------------------------------------------------------------------------------------- for i, group in enumerate(['pre', 'post']): - # Get test dataset based on group + # Get test dataset based on group, add column names X_test = pd.DataFrame(data_dic['X_test_' + group], columns=attribute_names) y_test = data_dic['y_test_' + group] for j, method in enumerate(['', '', 'over_', 'under_']): @@ -157,14 +157,15 @@ if __name__ == "__main__": # -------------------------------------------------------------------------------------------------------- # Fit model with training data fitted_model = model.fit(X_train[:500], y_train[:500]) - # Check if we are dealing with a tree vs nn model + # # Check if we are dealing with a tree vs nn model if is_tree: - explainer = shap.TreeExplainer(fitted_model, X_test[:500]) - else: - explainer = shap.KernelExplainer(fitted_model.predict_proba, X_test[:500]) + explainer = shap.TreeExplainer(fitted_model) + # else: + # explainer = shap.KernelExplainer(fitted_model.predict_proba, X_test[:500]) # Compute shap values shap_vals = explainer.shap_values(X_test[:500], check_additivity=False) # Change to true for final results # --------------------------------------------------------------------------------------------------------- # Save results np.save(f"./output/shap_values/{group}_{method_names[j]}", shap_vals) + print(f'Shape of numpy array: {shap_vals.shape}') # -------------------------------------------------------------------------------------------------------- \ No newline at end of file