Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Submit feedback
Sign in
Toggle navigation
C
covid_analysis
Project overview
Project overview
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
COMPARA
covid_analysis
Commits
bb2e28d5
Commit
bb2e28d5
authored
Jun 10, 2024
by
Joaquin Torres
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Working and ready to generate shap and shap interaction values
parent
2412d533
Changes
2
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
331851 additions
and
26068 deletions
+331851
-26068
explicability/shap_summary_plot.svg
explicability/shap_summary_plot.svg
+331807
-26062
explicability/shap_vals_testing.py
explicability/shap_vals_testing.py
+44
-6
No files found.
explicability/shap_summary_plot.svg
View file @
bb2e28d5
This diff is collapsed.
Click to expand it.
explicability/shap_vals_testing.py
View file @
bb2e28d5
...
...
@@ -117,6 +117,41 @@ def get_chosen_model(group_str, method_str, model_name):
return
chosen_model
,
is_tree
# --------------------------------------------------------------------------------------------------------
# Get balanced subset of n elements from original datasets
# --------------------------------------------------------------------------------------------------------
def
get_sample
(
X_train
,
y_train
,
X_test
,
y_test
,
n
):
# Convert numpy arrays to pandas series for easier handling if necessary
y_train
=
pd
.
Series
(
y_train
)
y_test
=
pd
.
Series
(
y_test
)
# Concatenate X and y for train and test to make it easier to work with
train
=
pd
.
concat
([
X_train
,
y_train
.
rename
(
'target'
)],
axis
=
1
)
test
=
pd
.
concat
([
X_test
,
y_test
.
rename
(
'target'
)],
axis
=
1
)
# Get n/2 samples from each class for the training set
train_0
=
train
[
train
[
'target'
]
==
0
]
.
sample
(
n
//
2
)
train_1
=
train
[
train
[
'target'
]
==
1
]
.
sample
(
n
//
2
)
# Concatenate the samples to form the balanced training set
balanced_train
=
pd
.
concat
([
train_0
,
train_1
])
# Get n/2 samples from each class for the testing set
test_0
=
test
[
test
[
'target'
]
==
0
]
.
sample
(
n
//
2
)
test_1
=
test
[
test
[
'target'
]
==
1
]
.
sample
(
n
//
2
)
# Concatenate the samples to form the balanced testing set
balanced_test
=
pd
.
concat
([
test_0
,
test_1
])
# Separate the features and the target variable for both sets
X_train_balanced
=
balanced_train
.
drop
(
'target'
,
axis
=
1
)
y_train_balanced
=
balanced_train
[
'target'
]
X_test_balanced
=
balanced_test
.
drop
(
'target'
,
axis
=
1
)
y_test_balanced
=
balanced_test
[
'target'
]
return
X_train_balanced
,
y_train_balanced
,
X_test_balanced
,
y_test_balanced
# --------------------------------------------------------------------------------------------------------
if
__name__
==
"__main__"
:
# Setup
...
...
@@ -147,12 +182,15 @@ if __name__ == "__main__":
y_test
=
data_dic
[
'y_test_'
+
group
]
X_train
=
data_dic
[
'X_train_'
+
method
+
group
]
y_train
=
data_dic
[
'y_train_'
+
method
+
group
]
X_train
,
y_train
,
X_test
,
y_test
=
get_sample
(
X_train
,
y_train
,
X_test
,
y_test
,
500
)
method_name
=
'UNDER'
# Get chosen tuned model for this group and method context
model
,
is_tree
=
get_chosen_model
(
group_str
=
group
,
method_str
=
method_name
,
model_name
=
model_choices
[
method_name
])
fit_start_t
=
time
.
time
()
# Fit model with training data
fitted_model
=
model
.
fit
(
X_train
[:
500
],
y_train
[:
500
]
)
fitted_model
=
model
.
fit
(
X_train
,
y_train
)
fit_end_t
=
time
.
time
()
print
(
f
'Fitted OK. Took {fit_end_t-fit_start_t} seconds.'
)
# Check if we are dealing with a tree vs nn model
...
...
@@ -164,18 +202,18 @@ if __name__ == "__main__":
shap_start_t
=
time
.
time
()
# Compute shap values
shap_val_start_t
=
time
.
time
()
shap_vals
=
explainer
.
shap_values
(
X_test
[:
500
]
,
check_additivity
=
False
)
# Change to true for final results
shap_vals
=
explainer
.
shap_values
(
X_test
,
check_additivity
=
False
)
# Change to true for final results
shap_val_end_t
=
time
.
time
()
print
(
f
'Shap values computed. Took {shap_val_end_t-shap_val_start_t} seconds.'
)
# Compute shap interaction values
shap_interaction_values
=
explainer
.
shap_interaction_values
(
X_test
[:
500
]
)
print
(
f
'Shape: {shap_interaction_values.shape}'
)
shap_interaction_values
=
explainer
.
shap_interaction_values
(
X_test
)
print
(
f
'Shape
Interaction Values
: {shap_interaction_values.shape}'
)
shap_end_t
=
time
.
time
()
print
(
f
'Interaction values computed. Took {shap_end_t - shap_start_t} seconds.'
)
# Plot interaction values accross variables
plot_start_t
=
time
.
time
()
shap
.
summary_plot
(
shap_interaction_values
,
X_test
[:
500
],
max_display
=
5
)
shap
.
summary_plot
(
shap_interaction_values
,
X_test
,
max_display
=
39
)
plot_end_t
=
time
.
time
()
print
(
f
'Plot done. Took {plot_end_t - plot_start_t} seconds.'
)
plt
.
savefig
(
'shap_summary_plot.svg'
,
dpi
=
1
000
)
plt
.
savefig
(
'shap_summary_plot.svg'
,
dpi
=
2
000
)
plt
.
close
()
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment