Commit 3b07abab authored by Joaquin Torres's avatar Joaquin Torres

script ready for shap vals computation

parent 36a3534e
......@@ -86,11 +86,13 @@ if __name__ == "__main__":
"post_OVER": (None,None),
"post_UNDER": (None,None),
}
# # Retrieve attribute names in order
# df = pd.read_csv("..\gen_train_data\data\input\pre_dataset.csv")
# attribute_names = list(df.columns.values)
# --------------------------------------------------------------------------------------------------------
# Shap value generation
# --------------------------------------------------------------------------------------------------------
shap_values = {} # Mapping group-method -> shap values
for i, group in enumerate(['pre', 'post']):
# Get test dataset based on group
X_test = data_dic['X_test_' + group]
......@@ -105,8 +107,14 @@ if __name__ == "__main__":
is_tree = model_info[0]
model = model_info[1]
# Fit model with training data
fitted_model = model.fit(X_train, y_train) # [:500]?
fitted_model = model.fit(X_train[:500], y_train[:500])
# Check if we are dealing with a tree vs nn model
if is_tree:
explainer = shap.TreeExplainer(fitted_model, X_test) # [:500]?
explainer = shap.TreeExplainer(fitted_model, X_test[:500])
else:
explainer = shap.KernelExplainer(fitted_model.predict, X_test[:500])
# Compute shap values
shap_vals = explainer.shap_values(X_test[:500], check_additivity=False) # Change to true for final results
# Save results
np.save(f"shap_values/{group}_{method_names[j]}", shap_vals)
# --------------------------------------------------------------------------------------------------------
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment