Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Submit feedback
Sign in
Toggle navigation
C
covid_analysis
Project overview
Project overview
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
COMPARA
covid_analysis
Commits
746831df
Commit
746831df
authored
Jun 17, 2024
by
Joaquin Torres
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Ready to generate SHAP plots
parent
866332c3
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
58 additions
and
3 deletions
+58
-3
explicability/compute_shap_inter_vals.py
explicability/compute_shap_inter_vals.py
+4
-3
explicability/shap_plots.py
explicability/shap_plots.py
+54
-0
No files found.
explicability/compute_shap_inter_vals.py
View file @
746831df
...
...
@@ -60,9 +60,10 @@ if __name__ == "__main__":
X_test
=
data_dic
[
'X_test_'
+
group
]
y_test
=
data_dic
[
'y_test_'
+
group
]
for
j
,
method
in
enumerate
([
''
,
''
,
'over_'
,
'under_'
]):
if
j
!=
1
:
print
(
'Skip'
)
continue
# Remove (used to isolate RF)
# if j != 1:
# print('Skip')
# continue
print
(
f
"{group}-{method_names[j]}"
)
method_name
=
method_names
[
j
]
model_name
=
model_choices
[
method_name
]
...
...
explicability/shap_plots.py
0 → 100644
View file @
746831df
# Libraries
# --------------------------------------------------------------------------------------------------------
import
pandas
as
pd
import
numpy
as
np
import
shap
# --------------------------------------------------------------------------------------------------------
# Reading test data
# --------------------------------------------------------------------------------------------------------
def
read_test_data
(
attribute_names
):
# Load test data
X_test_pre
=
np
.
load
(
'../gen_train_data/data/output/pre/X_test_pre.npy'
,
allow_pickle
=
True
)
y_test_pre
=
np
.
load
(
'../gen_train_data/data/output/pre/y_test_pre.npy'
,
allow_pickle
=
True
)
X_test_post
=
np
.
load
(
'../gen_train_data/data/output/post/X_test_post.npy'
,
allow_pickle
=
True
)
y_test_post
=
np
.
load
(
'../gen_train_data/data/output/post/y_test_post.npy'
,
allow_pickle
=
True
)
# Type conversion needed
data_dic
=
{
"X_test_pre"
:
pd
.
DataFrame
(
X_test_pre
,
columns
=
attribute_names
)
.
convert_dtypes
(),
"y_test_pre"
:
y_test_pre
,
"X_test_post"
:
pd
.
DataFrame
(
X_test_post
,
columns
=
attribute_names
)
.
convert_dtypes
(),
"y_test_post"
:
y_test_post
,
}
return
data_dic
# --------------------------------------------------------------------------------------------------------
if
__name__
==
"__main__"
:
# Setup
# --------------------------------------------------------------------------------------------------------
# Retrieve attribute names in order
attribute_names
=
list
(
np
.
load
(
'../gen_train_data/data/output/attributes.npy'
,
allow_pickle
=
True
))
# Reading data
data_dic
=
read_test_data
(
attribute_names
)
method_names
=
{
0
:
"ORIG"
,
1
:
"ORIG_CW"
,
2
:
"OVER"
,
3
:
"UNDER"
}
# --------------------------------------------------------------------------------------------------------
# Plot generation
# --------------------------------------------------------------------------------------------------------
for
i
,
group
in
enumerate
([
'pre'
,
'post'
]):
# Get test dataset based on group, add column names
X_test
=
data_dic
[
'X_test_'
+
group
]
y_test
=
data_dic
[
'y_test_'
+
group
]
for
j
,
method
in
enumerate
([
''
,
''
,
'over_'
,
'under_'
]):
print
(
f
"{group}-{method_names[j]}"
)
method_name
=
method_names
[
j
]
shap_vals
=
np
.
load
(
f
'./output/shap_values/{group}_{method_name}.npy'
)
print
(
f
'Loaded SHAP values. Shape: {shap_vals.shape}'
)
shap_inter_vals
=
np
.
load
(
f
'./output/shap_inter_values/{group}_{method_name}.npy'
)
print
(
f
'Loaded SHAP INTER values. Shape: {shap_inter_vals.shape}'
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment