Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Submit feedback
Sign in
Toggle navigation
T
TFM_AdrianAyusoMuñoz
Project overview
Project overview
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Disnet
GNNs
TFM_AdrianAyusoMuñoz
Commits
f3996047
Commit
f3996047
authored
Jun 25, 2023
by
ADRIAN AYUSO MUNOZ
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Update to topN.
parent
c7af1974
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
29 additions
and
39 deletions
+29
-39
results/dis_dru_the_20_table.csv
results/dis_dru_the_20_table.csv
+0
-21
testRepoDB.py
testRepoDB.py
+2
-2
testRepoDBWeightsAndBiases.py
testRepoDBWeightsAndBiases.py
+2
-2
topN.py
topN.py
+25
-14
No files found.
results/dis_dru_the_20_table.csv
deleted
100644 → 0
View file @
c7af1974
,dis,dis name,dru,dru name,pred
0,C0751137,Craniofacial Pain,CHEMBL1490,TRIHEXYPHENIDYL,1.0
1,C0086237,"Epilepsy, Cryptogenic",CHEMBL641,ATOMOXETINE,1.0
2,C0270824,Visual epilepsy,CHEMBL3833361,PROMETHAZINE TEOCLATE,1.0
3,C3151568,"NEPHROTIC SYNDROME, TYPE 4",CHEMBL121,ROSIGLITAZONE,1.0
4,C0086237,"Epilepsy, Cryptogenic",CHEMBL640,PROCAINAMIDE,1.0
5,C0270824,Visual epilepsy,CHEMBL3833362,ZINC OLEATE,1.0
6,C0270824,Visual epilepsy,CHEMBL3833364,HYDRARGAPHEN,1.0
7,C0270824,Visual epilepsy,CHEMBL3833368,RUCAPARIB CAMSYLATE,1.0
8,C0270824,Visual epilepsy,CHEMBL3833369,FISH OIL,1.0
9,C0270824,Visual epilepsy,CHEMBL3833373,AVELUMAB,1.0
10,C0270824,Visual epilepsy,CHEMBL3833381,MERBROMIN,1.0
11,C0270824,Visual epilepsy,CHEMBL3833382,PRAJMALIUM,1.0
12,C0270824,Visual epilepsy,CHEMBL3833383,GALLIUM DOTATATE GA-68,1.0
13,C0270824,Visual epilepsy,CHEMBL3833311,DEXTROMORAMIDE TARTRATE,1.0
14,C0270824,Visual epilepsy,CHEMBL3833388,PRAJMALIUM BITARTRATE,1.0
15,C0270824,Visual epilepsy,CHEMBL3833389,DEXTROMORAMIDE,1.0
16,C0270824,Visual epilepsy,CHEMBL3833393,EMICIZUMAB,1.0
17,C0270824,Visual epilepsy,CHEMBL3833401,ALUMINUM CHLORIDE,1.0
18,C0338656,Cognitive Dysfunction,CHEMBL3039594,GENTAMICIN SULFATE,1.0
19,C0162323,Polyarthritis,CHEMBL1003,CLAVULANATE POTASSIUM,1.0
testRepoDB.py
View file @
f3996047
...
...
@@ -147,9 +147,9 @@ def metrics(model):
pure_predictions
=
[
item
for
sublist
in
[
preds
,
predsN
]
for
item
in
sublist
]
labels
=
torch
.
tensor
([
item
for
sublist
in
[
labels1
,
labels2
]
for
item
in
sublist
])
fpr
,
tpr
,
label1
=
plot_roc
(
labels
,
pure_predictions
,
(
'disorder'
,
'dis_dru_the'
,
'drug'
),
"dmsr
-f
/"
,
fpr
,
tpr
,
label1
=
plot_roc
(
labels
,
pure_predictions
,
(
'disorder'
,
'dis_dru_the'
,
'drug'
),
"dmsr/"
,
"RepoDB"
)
recall
,
precision
,
label2
=
plot_prc
(
labels
,
pure_predictions
,
(
'disorder'
,
'dis_dru_the'
,
'drug'
),
"dmsr
-f
/"
,
recall
,
precision
,
label2
=
plot_prc
(
labels
,
pure_predictions
,
(
'disorder'
,
'dis_dru_the'
,
'drug'
),
"dmsr/"
,
"RepoDB"
)
plotMetrics
(
fpr
,
tpr
,
label1
,
recall
,
precision
,
label2
)
...
...
testRepoDBWeightsAndBiases.py
View file @
f3996047
...
...
@@ -159,9 +159,9 @@ def metrics(model):
pure_predictions
=
[
item
for
sublist
in
[
preds
,
predsN
]
for
item
in
sublist
]
labels
=
torch
.
tensor
([
item
for
sublist
in
[
labels1
,
labels2
]
for
item
in
sublist
])
fpr
,
tpr
,
label1
=
plot_roc
(
labels
,
pure_predictions
,
(
'disorder'
,
'dis_dru_the'
,
'drug'
),
"dmsr
-f
/"
,
fpr
,
tpr
,
label1
=
plot_roc
(
labels
,
pure_predictions
,
(
'disorder'
,
'dis_dru_the'
,
'drug'
),
"dmsr/"
,
"RepoDB"
)
recall
,
precision
,
label2
=
plot_prc
(
labels
,
pure_predictions
,
(
'disorder'
,
'dis_dru_the'
,
'drug'
),
"dmsr
-f
/"
,
recall
,
precision
,
label2
=
plot_prc
(
labels
,
pure_predictions
,
(
'disorder'
,
'dis_dru_the'
,
'drug'
),
"dmsr/"
,
"RepoDB"
)
plotMetrics
(
fpr
,
tpr
,
label1
,
recall
,
precision
,
label2
)
...
...
topN.py
View file @
f3996047
...
...
@@ -7,8 +7,8 @@ from deepsnap.dataset import GraphDataset
from
datetime
import
datetime
from
deepsnap.hetero_gnn
import
HeteroSAGEConv
device
=
torch
.
device
(
'cuda'
if
torch
.
cuda
.
is_available
()
else
'cpu'
)
# It defines whether to execute on cpu or gpu.
device
=
'cpu'
constructor
=
heterograph_construction
.
DISNETConstructor
(
device
=
device
)
# Graph constructor.
edge
=
(
'disorder'
,
'dis_dru_the'
,
'drug'
)
# Graph edge type to study.
n
=
200000
# Number of new predictions.
...
...
@@ -19,8 +19,10 @@ Input:
original: Graph.
pred: Edge predictions.
"""
def
filterPreds
(
original
,
pred
):
headsO
=
original
.
edge_index
[
edge
][
0
,
:]
.
long
()
# Heads of the original edges of the graph.
headsO
=
original
.
edge_index
[
edge
][
0
,
:]
.
long
()
# Heads of the original edges of the graph.
new
=
[]
for
i
,
elem
in
enumerate
(
pred
):
...
...
@@ -28,11 +30,11 @@ def filterPreds(original, pred):
head
=
i
tail
=
torch
.
arange
(
0
,
len
(
pred_labels
))
# All tails.
indexH
=
((
headsO
==
head
)
.
nonzero
(
as_tuple
=
True
)[
0
])
# Index of those heads originally present in the graph.
#Check
print
(
len
(
tail
))
for
index
in
indexH
:
tail
=
tail
[
tail
!=
index
]
# Just get those tails not present in the original graph.
print
(
len
(
tail
))
# Check
for
t
in
original
.
edge_index
[
edge
][
1
,
indexH
]
:
tail
=
tail
[
tail
!=
t
]
# Just get those tails not present in the original graph.
new
.
append
([
head
,
tail
,
pred_labels
[
tail
]
.
cpu
()
.
detach
()
.
numpy
()])
# New predictions are appended.
return
new
...
...
@@ -47,24 +49,29 @@ Input:
Output:
Dataframe containing the top n predictions ordered and decoded.
"""
def
getTopN
(
model
,
dataloader
,
n
):
print
(
" Looking for new edges."
)
for
batch
in
zip
(
dataloader
):
batch
.
to
(
device
)
batch
=
batch
[
0
]
pred
=
model
.
predict_all
(
batch
)
# Predict all edges.
new
=
filterPreds
(
batch
,
pred
)
# Filter those edges present in the original graph.
print
(
" Decoding predictions, this may take a while."
)
return
constructor
.
decodePredictions
(
new
,
edge
[
1
],
n
)
"""
It gets the heterograph object and its conversion to dataloader.
Output:
The heterograph and its dataloader.
"""
def
getOriginal
():
hetero
,
_
=
constructor
.
DISNETHeterograph
()
hetero
,
_
=
constructor
.
DISNETHeterograph
(
full
=
True
,
withoutRepoDB
=
False
)
dataset
=
GraphDataset
(
[
hetero
],
task
=
'link_pred'
,
...
...
@@ -76,17 +83,20 @@ def getOriginal():
)
return
dataset_loader
,
hetero
"""
It wraps all the necessary calls to get the top n predictions of the DMSR model.
"""
def
dmsr
():
# Necessary instantiations.
original
,
hetero
=
getOriginal
()
convs
=
generate_convs_link_pred_layers
(
hetero
,
HeteroSAGEConv
,
107
)
model
=
HeteroGNN
(
convs
,
hetero
,
107
,
0.8
)
.
to
(
device
)
.
to
(
device
)
convs
=
generate_convs_link_pred_layers
(
hetero
,
HeteroSAGEConv
,
31
)
model
=
HeteroGNN
(
convs
,
hetero
,
31
,
0.5
)
.
to
(
device
)
# Load and prepare for inference model.
model
.
load_state_dict
(
torch
.
load
(
"./models/dmsr"
,
map_location
=
torch
.
device
(
device
)))
model
.
load_state_dict
(
torch
.
load
(
"./models/dmsr
C
"
,
map_location
=
torch
.
device
(
device
)))
model
=
model
.
to
(
device
)
model
.
eval
()
...
...
@@ -95,6 +105,7 @@ def dmsr():
_
=
getTopN
(
model
,
original
,
n
)
print
(
"Finished getting top"
,
n
,
"at"
,
datetime
.
now
()
.
strftime
(
"
%
H:
%
M:
%
S"
))
if
__name__
==
'__main__'
:
dmsr
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment