Commit f919e066 authored by Joaquin Torres's avatar Joaquin Torres

Identified problem with features

parent 7b58b74c
......@@ -143,7 +143,7 @@ if __name__ == "__main__":
# Shap value generation
# --------------------------------------------------------------------------------------------------------
for i, group in enumerate(['pre', 'post']):
# Get test dataset based on group
# Get test dataset based on group, add column names
X_test = pd.DataFrame(data_dic['X_test_' + group], columns=attribute_names)
y_test = data_dic['y_test_' + group]
for j, method in enumerate(['', '', 'over_', 'under_']):
......@@ -157,14 +157,15 @@ if __name__ == "__main__":
# --------------------------------------------------------------------------------------------------------
# Fit model with training data
fitted_model = model.fit(X_train[:500], y_train[:500])
# Check if we are dealing with a tree vs nn model
# # Check if we are dealing with a tree vs nn model
if is_tree:
explainer = shap.TreeExplainer(fitted_model, X_test[:500])
else:
explainer = shap.KernelExplainer(fitted_model.predict_proba, X_test[:500])
explainer = shap.TreeExplainer(fitted_model)
# else:
# explainer = shap.KernelExplainer(fitted_model.predict_proba, X_test[:500])
# Compute shap values
shap_vals = explainer.shap_values(X_test[:500], check_additivity=False) # Change to true for final results
# ---------------------------------------------------------------------------------------------------------
# Save results
np.save(f"./output/shap_values/{group}_{method_names[j]}", shap_vals)
print(f'Shape of numpy array: {shap_vals.shape}')
# --------------------------------------------------------------------------------------------------------
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment