Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Submit feedback
Sign in
Toggle navigation
T
TFM_AdrianAyusoMuñoz
Project overview
Project overview
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Disnet
GNNs
TFM_AdrianAyusoMuñoz
Commits
6d263e1a
Commit
6d263e1a
authored
Jun 28, 2023
by
ADRIAN AYUSO MUNOZ
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Reformat files
parent
852e50e1
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
60 additions
and
64 deletions
+60
-64
autoencoder.py
autoencoder.py
+4
-0
dmsr.py
dmsr.py
+4
-3
drug_embedding_generator.py
drug_embedding_generator.py
+6
-1
heterograph_construction.py
heterograph_construction.py
+7
-6
testRepoDB.py
testRepoDB.py
+8
-9
testRepoDBWeightsAndBiases.py
testRepoDBWeightsAndBiases.py
+1
-9
topN.py
topN.py
+0
-8
utilities.py
utilities.py
+30
-28
No files found.
autoencoder.py
View file @
6d263e1a
...
...
@@ -14,6 +14,7 @@ class GCNEncoder(torch.nn.Module):
in_channels: Size of the input embeddings.
out_channels: Size of the output embeddings.
"""
def
__init__
(
self
,
in_channels
,
out_channels
):
super
(
GCNEncoder
,
self
)
.
__init__
()
self
.
conv1
=
SAGEConv
(
in_channels
,
2
*
out_channels
)
# cached only for transductive learning
...
...
@@ -28,11 +29,13 @@ class GCNEncoder(torch.nn.Module):
Output:
Encoded embeddings.
"""
def
forward
(
self
,
x
,
edge_index
):
x
=
self
.
conv1
(
x
,
edge_index
)
x
=
self
.
lrelu
(
x
)
return
self
.
conv2
(
x
,
edge_index
)
"""
Class to wrap the training and testing functions.
"""
...
...
@@ -44,6 +47,7 @@ class Trainer:
optimizer: Optimizer to train the model.
dataset: Dataset on which the model will be trained.
"""
def
__init__
(
self
,
model
,
optimizer
,
dataset
):
self
.
model
=
model
self
.
optimizer
=
optimizer
...
...
dmsr.py
View file @
6d263e1a
...
...
@@ -28,7 +28,6 @@ from deepsnap.hetero_gnn import (
edges
=
[(
'disorder'
,
'dis_dru_the'
,
'drug'
)]
# It sets the edges that will be studied.
device
=
torch
.
device
(
'cuda'
if
torch
.
cuda
.
is_available
()
else
'cpu'
)
# It defines whether to execute on cpu or gpu.
# ---------------------------
# FUNCTIONS
# ---------------------------
...
...
@@ -55,6 +54,7 @@ def generate_convs_link_pred_layers(hete, conv, hidden_size):
convs2
[
message_type
]
=
conv
(
hidden_size
,
hidden_size
,
hidden_size
)
return
[
convs1
,
convs2
]
"""
Heterogeneous Graph Neural Network.
"""
...
...
@@ -228,6 +228,7 @@ class HeteroGNN(torch.nn.Module):
loss
+=
self
.
loss_fn
(
p
,
y
[
key
]
.
type
(
pred
[
key
]
.
dtype
))
return
loss
"""
It trains the model.
Input:
...
...
@@ -345,6 +346,7 @@ def test2(model, test_loader):
keys
.
append
(
key
)
return
pure_pred_labels
,
true_labels
,
keys
"""
It wraps all the functions for training and testing the model.
Input:
...
...
@@ -420,5 +422,4 @@ def main(epochs, hidden_dim, lr, weight_decay, dropout):
if
__name__
==
'__main__'
:
# Set of hyperparameters to train the model.
main
(
2343
,
31
,
0.0010235455088934942
,
0.005144745056173074
,
0.5
)
main
(
2343
,
31
,
0.0010235455088934942
,
0.005144745056173074
,
0.5
)
drug_embedding_generator.py
View file @
6d263e1a
...
...
@@ -24,12 +24,14 @@ class DrugGraph:
neg_edge_index: Two lists which contain negative edges, not present in the graph (first contains heads and
second tails)
"""
def
__init__
(
self
,
graph
,
feats
,
edge_index
,
neg_edge_index
):
self
.
graph
=
graph
self
.
x
=
feats
self
.
edge_index
=
edge_index
self
.
neg_edge_index
=
neg_edge_index
"""
It generates a CSV with the SMILES representations of the DISNET's graph drugs.
"""
...
...
@@ -50,6 +52,7 @@ def getSmiles():
newDf
.
to_csv
(
"data/druStruc.tsv"
,
sep
=
'
\t
'
,
index
=
False
)
"""
It transforms the SMILES representations to their graph representations. In case of error instead of the graph a 0 is
given.
...
...
@@ -83,6 +86,7 @@ def getGraph(df, complete=False):
df
.
to_csv
(
"data/druStruc.tsv"
,
sep
=
'
\t
'
,
index
=
False
)
return
np
.
array
(
networks
,
dtype
=
object
)
"""
It builds a dataset made of the drugs molecular structures.
Input:
...
...
@@ -169,6 +173,7 @@ def trainAE(model, optimizer, dataset, epochs):
trainer
=
Trainer
(
model
,
optimizer
,
dataset
)
trainer
.
fit
(
epochs
)
"""
It generates an embedding for a given graph applying all the instantiated layers in the previous function. The embedding
is the code of the autoencoder.
...
...
@@ -210,7 +215,7 @@ if __name__ == '__main__':
# Getting Embeddings
embeddings
=
getEmbed
(
model
,
dataset
)
# Saving model.
torch
.
save
(
model
.
state_dict
(),
"./models/structureEmbedder"
)
...
...
heterograph_construction.py
View file @
6d263e1a
...
...
@@ -63,7 +63,7 @@ class DISNETConstructor:
ddi_dru
=
pd
.
read_csv
(
'data/links/ddi_dru.tsv'
,
sep
=
'
\t
'
)
return
dis_dru_the
,
dis_sym
,
dis_pat
,
dis_pro
,
dru_dru
,
dru_pro
,
dru_sym_ind
,
dru_sym_sef
,
pro_pat
,
\
pro_pro
,
ddi_phe
,
ddi_dru
pro_pro
,
ddi_phe
,
ddi_dru
else
:
return
dis_dru_the
,
dis_sym
...
...
@@ -137,7 +137,8 @@ class DISNETConstructor:
ddi
[
'NID'
]
=
ddi
.
index
ddi
[
'node_type'
]
=
'drug-drug-interaction'
ddi
[
'node_id'
]
=
nodes_flat
.
loc
[
nodes_flat
[
'node_type'
]
==
'drug-drug-interaction'
]
.
reset_index
(
drop
=
True
)
.
node_id
ddi
[
'node_id'
]
=
nodes_flat
.
loc
[
nodes_flat
[
'node_type'
]
==
'drug-drug-interaction'
]
.
reset_index
(
drop
=
True
)
.
node_id
# Nodes dataframes to dict to apply map later
pat_dict
=
pat
[[
'id'
,
'node_id'
]]
.
set_index
(
'id'
)
.
to_dict
()[
'node_id'
]
...
...
@@ -148,7 +149,6 @@ class DISNETConstructor:
pro_feat
=
torch
.
tensor
([[
1
]
*
100
]
*
nsizes
[
'protein'
],
dtype
=
torch
.
float32
)
ddi_feat
=
torch
.
tensor
([[
1
]
*
100
]
*
nsizes
[
'drug-drug-interaction'
],
dtype
=
torch
.
float32
)
feats
=
{
'disorder'
:
dis_feat
,
'drug'
:
dru_feat
,
'pathway'
:
pat_feat
,
'protein'
:
pro_feat
,
'drug-drug-interaction'
:
ddi_feat
}
...
...
@@ -163,7 +163,7 @@ class DISNETConstructor:
# DataFrames of each type of edges
if
full
:
dis_dru_the
,
dis_sym
,
dis_pat
,
dis_pro
,
dru_dru
,
dru_pro
,
dru_sym_ind
,
dru_sym_sef
,
pro_pat
,
\
pro_pro
,
ddi_phe
,
ddi_dru
=
self
.
getEdgeInfo
(
full
)
pro_pro
,
ddi_phe
,
ddi_dru
=
self
.
getEdgeInfo
(
full
)
dis_pat
[
'disNID'
]
=
dis_pat
.
dis
.
map
(
dis_dict
)
dis_pat
[
'patNID'
]
=
dis_pat
.
pat
.
map
(
pat_dict
)
...
...
@@ -215,7 +215,7 @@ class DISNETConstructor:
dis_dru_the
=
pd
.
concat
([
dis_dru_the
,
dis_dru_the_repoDBAll
])
dis_dru_the
.
drop_duplicates
(
keep
=
False
,
inplace
=
True
,
ignore_index
=
True
)
dis_dru_the
=
dis_dru_the
[:
50355
]
# Those of RepoDB added for the concatenation are deleted.
dis_dru_the
=
dis_dru_the
[:
50355
]
# Those of RepoDB added for the concatenation are deleted.
dis_dru_the_repoDBAll
=
(
torch
.
tensor
(
dis_dru_the_repoDBAll
[
'disNID'
]
.
astype
(
np
.
int32
)
.
to_numpy
(),
dtype
=
torch
.
int32
,
...
...
@@ -321,7 +321,8 @@ class DISNETConstructor:
# Add the edges to the graph.
for
edge_t
in
edges_dict
.
keys
():
for
edge
in
edges
[
edge_t
]:
# If a relation between a phenotype and a drug is in more than one type. The final type will be the last one.
for
edge
in
edges
[
edge_t
]:
# If a relation between a phenotype and a drug is in more than one type. The final type will be the last one.
try
:
G
.
add_edge
(
int
(
edge
[
0
]),
int
(
edge
[
1
]),
edge_feature
=
edge
[
2
],
edge_type
=
edge_t
)
except
IndexError
:
...
...
testRepoDB.py
View file @
6d263e1a
...
...
@@ -15,7 +15,6 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # It defi
constructor
=
heterograph_construction
.
DISNETConstructor
(
device
=
device
)
# Graph constructor.
toStudy
=
'dis_dru_the'
# Graph edge type to study.
"""
It gets the predictions of the model and decodes them.
Input:
...
...
@@ -50,6 +49,7 @@ def randomEids():
tensor2
=
torch
.
randint
(
0
,
3944
,
(
5013
,),
device
=
torch
.
device
(
device
))
return
(
tensor1
,
tensor2
)
"""
It plots the metrics for the results of the real edge and the random edge set. It joins them vertically and horizontally.
Input:
...
...
@@ -161,25 +161,25 @@ if __name__ == '__main__':
rocL
,
prcL
=
np
.
array
([]),
np
.
array
([])
# Number of iterations.
k
=
50
# Set of hyperparameters.
epochs
=
2343
epochs
=
2343
hidden_dim
=
31
lr
=
0.0010235455088934942
weight_decay
=
0.005144745056173074
dropout
=
0.5
# Train and test k models and obtain their metrics.
for
i
in
range
(
k
):
model
=
main
(
epochs
,
hidden_dim
,
lr
,
weight_decay
,
dropout
)
roc1
,
prc1
=
metrics
(
model
)
rocL
=
np
.
append
(
rocL
,
roc1
)
prcL
=
np
.
append
(
prcL
,
prc1
)
# Average of the metrics of all the generated models.
rocM
=
sum
(
rocL
)
/
k
prcM
=
sum
(
prcL
)
/
k
# Obtain confidence intervals, if number of samples is under 30 t-distribution is used, if over 30 the normal
# distribution is used.
if
k
<
30
:
...
...
@@ -189,6 +189,5 @@ if __name__ == '__main__':
r
=
st
.
norm
.
interval
(
0.95
,
loc
=
np
.
mean
(
rocL
),
scale
=
st
.
sem
(
rocL
))
p
=
st
.
norm
.
interval
(
0.95
,
loc
=
np
.
mean
(
prcL
),
scale
=
st
.
sem
(
prcL
))
print
(
"AUCROC: "
,
rocM
,
"+-"
,
rocM
-
r
[
0
])
print
(
"AUCPR: "
,
prcM
,
"+-"
,
prcM
-
p
[
0
])
print
(
"AUCROC: "
,
rocM
,
"+-"
,
rocM
-
r
[
0
])
print
(
"AUCPR: "
,
prcM
,
"+-"
,
prcM
-
p
[
0
])
testRepoDBWeightsAndBiases.py
View file @
6d263e1a
...
...
@@ -13,6 +13,7 @@ import scipy.stats as st
# Import W&B
import
wandb
wandb
.
init
(
project
=
"dmsr"
,
entity
=
"ayusoupm"
)
device
=
torch
.
device
(
'cuda'
if
torch
.
cuda
.
is_available
()
else
'cpu'
)
# It defines whether to execute on cpu or gpu.
...
...
@@ -29,8 +30,6 @@ Input:
Output:
Dataframe containing all the predictions ordered and decoded.
"""
def
getDecode
(
model
,
eid
,
dataloader
,
random
=
''
):
print
(
" Looking for new edges."
)
for
batch
in
dataloader
:
...
...
@@ -50,8 +49,6 @@ It generates random edges in the graph.
Output:
Randomly generated edges.
"""
def
randomEids
():
tensor1
=
torch
.
randint
(
0
,
30729
,
(
5013
,),
device
=
torch
.
device
(
device
))
tensor2
=
torch
.
randint
(
0
,
3944
,
(
5013
,),
device
=
torch
.
device
(
device
))
...
...
@@ -68,8 +65,6 @@ Input:
precision: Precision.
label2: Area Under the PR curve.
"""
def
plotMetrics
(
fpr
,
tpr
,
label1
,
recall
,
precision
,
label2
):
# Vertical plotting.
fig
,
axs
=
plt
.
subplots
(
2
,
figsize
=
(
6
,
10
))
...
...
@@ -127,8 +122,6 @@ Input:
Output:
Area Under the ROC and PR curve.
"""
def
metrics
(
model
):
hetero
,
eids
=
constructor
.
DISNETHeterograph
(
full
=
False
,
withoutRepoDB
=
True
)
dataset
=
GraphDataset
(
...
...
@@ -210,4 +203,3 @@ if __name__ == '__main__':
print
(
"AUCROC: "
,
rocM
,
"+-"
,
rocM
-
r
[
0
])
print
(
"AUCPR: "
,
prcM
,
"+-"
,
prcM
-
p
[
0
])
topN.py
View file @
6d263e1a
...
...
@@ -19,8 +19,6 @@ Input:
original: Graph.
pred: Edge predictions.
"""
def
filterPreds
(
original
,
pred
):
headsO
=
original
.
edge_index
[
edge
][
0
,
:]
.
long
()
# Heads of the original edges of the graph.
...
...
@@ -49,8 +47,6 @@ Input:
Output:
Dataframe containing the top n predictions ordered and decoded.
"""
def
getTopN
(
model
,
dataloader
,
n
):
print
(
" Looking for new edges."
)
for
batch
in
zip
(
dataloader
):
...
...
@@ -68,8 +64,6 @@ It gets the heterograph object and its conversion to dataloader.
Output:
The heterograph and its dataloader.
"""
def
getOriginal
():
hetero
,
_
=
constructor
.
DISNETHeterograph
(
full
=
True
,
withoutRepoDB
=
False
)
dataset
=
GraphDataset
(
...
...
@@ -88,7 +82,6 @@ def getOriginal():
It wraps all the necessary calls to get the top n predictions of the DMSR model.
"""
def
dmsr
():
# Necessary instantiations.
original
,
hetero
=
getOriginal
()
...
...
@@ -108,4 +101,3 @@ def dmsr():
if
__name__
==
'__main__'
:
dmsr
()
utilities.py
View file @
6d263e1a
...
...
@@ -3,7 +3,6 @@ import pandas as pd
import
sklearn.metrics
as
metrics
import
matplotlib.pyplot
as
plt
"""
Calculate and plot ROC curve.
Input:
...
...
@@ -16,12 +15,12 @@ Output:
"""
def
plot_roc
(
y_test
,
preds
,
edge
,
extension
=
''
):
fpr
,
tpr
,
threshold
=
metrics
.
roc_curve
(
y_test
,
preds
)
roc_auc
=
metrics
.
auc
(
fpr
,
tpr
)
roc_auc
=
metrics
.
auc
(
fpr
,
tpr
)
label
=
" "
.
join
(
edge
)
+
" "
+
'AUC =
%0.2
f'
%
roc_auc
fileName
=
'metrics/aucroc'
+
extension
+
'.svg'
fileName
=
'metrics/aucroc'
+
extension
+
'.svg'
random
=
[[
0
,
1
],
[
0
,
1
],
'r--'
]
plotAndSaveFig
(
title
=
'ROC Curve'
,
x
=
fpr
,
y
=
tpr
,
label
=
label
,
path
=
fileName
,
loc
=
'lower right'
,
xlim
=
[
0
,
1
],
ylim
=
[
0
,
1
],
xlabel
=
'False Positive Rate'
,
ylabel
=
'True Positive Rate'
,
random
=
random
)
...
...
@@ -43,16 +42,17 @@ Output:
def
plot_prc
(
y_true
,
y_pred
,
edge
,
extension
=
''
):
precision
,
recall
,
trhesholds
=
metrics
.
precision_recall_curve
(
y_true
,
y_pred
)
average_precision
=
metrics
.
average_precision_score
(
y_true
,
y_pred
)
label
=
" "
.
join
(
edge
)
+
" "
+
'RPC =
%0.2
f'
%
average_precision
fileName
=
'metrics/prc'
+
extension
+
'.svg'
fileName
=
'metrics/prc'
+
extension
+
'.svg'
plotAndSaveFig
(
title
=
'Precision-Recall Curve'
,
x
=
recall
,
y
=
precision
,
label
=
label
,
path
=
fileName
,
loc
=
'lower right'
,
xlim
=
[
0
,
1
],
ylim
=
[
0
,
1
],
xlabel
=
'Recall'
,
ylabel
=
'Precision'
)
print
(
"PRC:"
,
average_precision
)
return
recall
,
precision
,
average_precision
"""
Plot distribution of predictions of the true and random edges. Plots the distribution function and histogram.
Input:
...
...
@@ -66,17 +66,18 @@ def plot_dist(extension=''):
ax
.
set_xticks
(
np
.
arange
(
0
,
1.1
,
0.1
))
ax
.
set_title
(
'RepoDB & Random Prediction Histogram'
)
ax
.
set_xlabel
(
'Prediction Score'
)
ax
.
figure
.
savefig
(
'results/histogram'
+
extension
+
'.svg'
,
format
=
'svg'
,
dpi
=
1200
)
ax
.
figure
.
savefig
(
'results/histogram'
+
extension
+
'.svg'
,
format
=
'svg'
,
dpi
=
1200
)
ax
.
figure
.
clf
()
bx
=
df
.
plot
.
kde
()
bx
.
set_title
(
'RepoDB & Random Prediction'
)
bx
.
set_xlabel
(
'Prediction Score'
)
bx
.
figure
.
savefig
(
'results/distribution'
+
extension
+
'.svg'
,
format
=
'svg'
,
dpi
=
1200
)
bx
.
figure
.
savefig
(
'results/distribution'
+
extension
+
'.svg'
,
format
=
'svg'
,
dpi
=
1200
)
bx
.
set_xlim
([
0
,
1
])
bx
.
figure
.
savefig
(
'results/distribution01'
+
extension
+
'.svg'
,
format
=
'svg'
,
dpi
=
1200
)
bx
.
figure
.
savefig
(
'results/distribution01'
+
extension
+
'.svg'
,
format
=
'svg'
,
dpi
=
1200
)
bx
.
figure
.
clf
()
"""
It plots and saves the figure sent.
Input:
...
...
@@ -95,29 +96,30 @@ Input:
def
plotAndSaveFig
(
title
,
x
,
y
,
label
,
path
,
loc
=
None
,
xlim
=
None
,
ylim
=
None
,
xlabel
=
None
,
ylabel
=
None
,
random
=
None
):
plt
.
title
(
title
)
plt
.
plot
(
x
,
y
,
label
=
label
)
if
loc
is
not
None
:
plt
.
legend
(
loc
=
loc
)
if
xlim
is
not
None
:
plt
.
xlim
(
xlim
)
if
ylim
is
not
None
:
plt
.
ylim
(
ylim
)
if
xlabel
is
not
None
:
plt
.
xlabel
(
xlabel
)
if
ylabel
is
not
None
:
plt
.
ylabel
(
ylabel
)
if
random
is
not
None
:
plt
.
plot
(
random
[
0
],
random
[
1
],
random
[
2
])
plt
.
show
()
plt
.
savefig
(
path
,
format
=
'svg'
,
dpi
=
1200
)
plt
.
clf
()
"""
It plots and saves together the figures sent.
Input:
...
...
@@ -136,31 +138,31 @@ Input:
def
plotTogether
(
title
,
x
,
y
,
label
,
path
,
loc
=
None
,
xlim
=
None
,
ylim
=
None
,
xlabel
=
None
,
ylabel
=
None
,
random
=
None
):
figH
,
axs
=
plt
.
subplots
(
2
,
figsize
=
(
6
,
10
))
figV
,
axs2
=
plt
.
subplots
(
1
,
2
,
figsize
=
(
12
,
4
))
for
i
,
elem
in
enumerate
(
axs
):
elem
.
plot
(
x
[
i
],
y
[
i
],
label
=
label
[
i
])
elem
.
set_title
(
title
[
i
])
if
loc
[
i
]
is
not
None
:
elem
.
legend
(
loc
=
loc
[
i
])
if
random
[
i
]
is
not
None
:
elem
.
plot
(
random
[
i
][
0
],
random
[
i
][
1
],
random
[
i
][
2
])
if
xlim
[
i
]
is
not
None
:
elem
.
set_xlim
(
xlim
[
i
])
if
ylim
[
i
]
is
not
None
:
elem
.
set_ylim
(
ylim
[
i
])
if
xlabel
[
i
]
is
not
None
:
elem
.
set_xlabel
(
xlabel
)
if
ylabel
[
i
]
is
not
None
:
elem
.
set_ylabel
(
ylabel
)
axs2
[
0
]
=
axs
[
0
]
axs2
[
1
]
=
axs
[
1
]
figH
.
savefig
(
path
[
0
],
format
=
'svg'
,
dpi
=
1200
)
figV
.
savefig
(
path
[
1
],
format
=
'svg'
,
dpi
=
1200
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment