Commit bb2e28d5 authored by Joaquin Torres

Working and ready to generate shap and shap interaction values

parent 2412d533
@@ -117,6 +117,41 @@ def get_chosen_model(group_str, method_str, model_name):
     return chosen_model, is_tree
 # --------------------------------------------------------------------------------------------------------
+# Get balanced subset of n elements from original datasets
+# --------------------------------------------------------------------------------------------------------
+def get_sample(X_train, y_train, X_test, y_test, n):
+    # Convert numpy arrays to pandas series for easier handling if necessary
+    y_train = pd.Series(y_train)
+    y_test = pd.Series(y_test)
+    # Concatenate X and y for train and test to make it easier to work with
+    train = pd.concat([X_train, y_train.rename('target')], axis=1)
+    test = pd.concat([X_test, y_test.rename('target')], axis=1)
+    # Get n/2 samples from each class for the training set
+    train_0 = train[train['target'] == 0].sample(n//2)
+    train_1 = train[train['target'] == 1].sample(n//2)
+    # Concatenate the samples to form the balanced training set
+    balanced_train = pd.concat([train_0, train_1])
+    # Get n/2 samples from each class for the testing set
+    test_0 = test[test['target'] == 0].sample(n//2)
+    test_1 = test[test['target'] == 1].sample(n//2)
+    # Concatenate the samples to form the balanced testing set
+    balanced_test = pd.concat([test_0, test_1])
+    # Separate the features and the target variable for both sets
+    X_train_balanced = balanced_train.drop('target', axis=1)
+    y_train_balanced = balanced_train['target']
+    X_test_balanced = balanced_test.drop('target', axis=1)
+    y_test_balanced = balanced_test['target']
+    return X_train_balanced, y_train_balanced, X_test_balanced, y_test_balanced
+# --------------------------------------------------------------------------------------------------------
 if __name__ == "__main__":
     # Setup
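For reference, a minimal standalone sketch of how the new get_sample helper behaves. It assumes the function above is importable, and the toy data, column names, and subset size are illustrative assumptions, not part of the commit.

```python
import numpy as np
import pandas as pd

# Toy data (assumed for illustration): 200 train / 100 test rows, 3 features, binary target.
rng = np.random.default_rng(0)
X_train = pd.DataFrame(rng.normal(size=(200, 3)), columns=['f1', 'f2', 'f3'])
y_train = rng.integers(0, 2, size=200)   # numpy array, converted to a Series inside get_sample
X_test = pd.DataFrame(rng.normal(size=(100, 3)), columns=['f1', 'f2', 'f3'])
y_test = rng.integers(0, 2, size=100)

# Draw a class-balanced subset of 50 rows from each split (25 per class).
X_tr_b, y_tr_b, X_te_b, y_te_b = get_sample(X_train, y_train, X_test, y_test, 50)
print(X_tr_b.shape, y_tr_b.value_counts().to_dict())   # (50, 3) and 25 rows per class
```

Note that .sample() is called without a fixed random_state, so repeated runs draw different balanced subsets.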
@@ -147,12 +182,15 @@ if __name__ == "__main__":
 y_test = data_dic['y_test_' + group]
 X_train = data_dic['X_train_' + method + group]
 y_train = data_dic['y_train_' + method + group]
+X_train, y_train, X_test, y_test = get_sample(X_train, y_train, X_test, y_test, 500)
 method_name = 'UNDER'
 # Get chosen tuned model for this group and method context
 model, is_tree = get_chosen_model(group_str=group, method_str=method_name, model_name=model_choices[method_name])
 fit_start_t = time.time()
 # Fit model with training data
-fitted_model = model.fit(X_train[:500], y_train[:500])
+fitted_model = model.fit(X_train, y_train)
 fit_end_t = time.time()
 print(f'Fitted OK. Took {fit_end_t-fit_start_t} seconds.')
 # Check if we are dealing with a tree vs nn model
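The explainer itself is built in the lines collapsed between this hunk and the next; as a rough sketch of how the is_tree flag returned by get_chosen_model might map to a shap explainer (the branches and the background-sample size below are assumptions, not the commit's actual code):

```python
import shap

# Hypothetical sketch: choose an explainer based on the is_tree flag.
# TreeExplainer supports shap_interaction_values; the fallback branch is assumed.
if is_tree:
    explainer = shap.TreeExplainer(fitted_model)
else:
    background = shap.sample(X_train, 100)  # small background set for the kernel method
    explainer = shap.KernelExplainer(fitted_model.predict_proba, background)
```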
@@ -164,18 +202,18 @@
 shap_start_t = time.time()
 # Compute shap values
 shap_val_start_t = time.time()
-shap_vals = explainer.shap_values(X_test[:500], check_additivity=False) # Change to true for final results
+shap_vals = explainer.shap_values(X_test, check_additivity=False) # Change to true for final results
 shap_val_end_t = time.time()
 print(f'Shap values computed. Took {shap_val_end_t-shap_val_start_t} seconds.')
 # Compute shap interaction values
-shap_interaction_values = explainer.shap_interaction_values(X_test[:500])
-print(f'Shape: {shap_interaction_values.shape}')
+shap_interaction_values = explainer.shap_interaction_values(X_test)
+print(f'Shape Interaction Values: {shap_interaction_values.shape}')
 shap_end_t = time.time()
 print(f'Interaction values computed. Took {shap_end_t - shap_start_t} seconds.')
 # Plot interaction values across variables
 plot_start_t = time.time()
-shap.summary_plot(shap_interaction_values, X_test[:500], max_display=5)
+shap.summary_plot(shap_interaction_values, X_test, max_display=39)
 plot_end_t = time.time()
 print(f'Plot done. Took {plot_end_t - plot_start_t} seconds.')
-plt.savefig('shap_summary_plot.svg', dpi=1000)
+plt.savefig('shap_summary_plot.svg', dpi=2000)
 plt.close()
\ No newline at end of file
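For reference, a self-contained sketch of the interaction-value flow on toy data; the model choice, feature count, and output file name are placeholders, not the commit's setup. With TreeExplainer, shap_interaction_values returns an array of shape (n_samples, n_features, n_features), which summary_plot renders as a per-feature interaction overview.

```python
import matplotlib.pyplot as plt
import numpy as np
import shap
from sklearn.ensemble import GradientBoostingClassifier

# Toy binary-classification data (assumed for illustration).
rng = np.random.default_rng(0)
X = rng.normal(size=(300, 6))
y = (X[:, 0] + X[:, 1] * X[:, 2] > 0).astype(int)

model = GradientBoostingClassifier(random_state=0).fit(X, y)
explainer = shap.TreeExplainer(model)

# Interaction values: one n_features x n_features matrix per explained row.
inter_vals = explainer.shap_interaction_values(X[:100])
print(inter_vals.shape)  # (100, 6, 6)

# Summary plot over the interaction matrices, then save to a placeholder file.
shap.summary_plot(inter_vals, X[:100], max_display=6, show=False)
plt.savefig('toy_interaction_summary.svg')
plt.close()
```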