Commit f3996047 authored by ADRIAN AYUSO MUNOZ

Update to topN.

parent c7af1974
,dis,dis name,dru,dru name,pred
0,C0751137,Craniofacial Pain,CHEMBL1490,TRIHEXYPHENIDYL,1.0
1,C0086237,"Epilepsy, Cryptogenic",CHEMBL641,ATOMOXETINE,1.0
2,C0270824,Visual epilepsy,CHEMBL3833361,PROMETHAZINE TEOCLATE,1.0
3,C3151568,"NEPHROTIC SYNDROME, TYPE 4",CHEMBL121,ROSIGLITAZONE,1.0
4,C0086237,"Epilepsy, Cryptogenic",CHEMBL640,PROCAINAMIDE,1.0
5,C0270824,Visual epilepsy,CHEMBL3833362,ZINC OLEATE,1.0
6,C0270824,Visual epilepsy,CHEMBL3833364,HYDRARGAPHEN,1.0
7,C0270824,Visual epilepsy,CHEMBL3833368,RUCAPARIB CAMSYLATE,1.0
8,C0270824,Visual epilepsy,CHEMBL3833369,FISH OIL,1.0
9,C0270824,Visual epilepsy,CHEMBL3833373,AVELUMAB,1.0
10,C0270824,Visual epilepsy,CHEMBL3833381,MERBROMIN,1.0
11,C0270824,Visual epilepsy,CHEMBL3833382,PRAJMALIUM,1.0
12,C0270824,Visual epilepsy,CHEMBL3833383,GALLIUM DOTATATE GA-68,1.0
13,C0270824,Visual epilepsy,CHEMBL3833311,DEXTROMORAMIDE TARTRATE,1.0
14,C0270824,Visual epilepsy,CHEMBL3833388,PRAJMALIUM BITARTRATE,1.0
15,C0270824,Visual epilepsy,CHEMBL3833389,DEXTROMORAMIDE,1.0
16,C0270824,Visual epilepsy,CHEMBL3833393,EMICIZUMAB,1.0
17,C0270824,Visual epilepsy,CHEMBL3833401,ALUMINUM CHLORIDE,1.0
18,C0338656,Cognitive Dysfunction,CHEMBL3039594,GENTAMICIN SULFATE,1.0
19,C0162323,Polyarthritis,CHEMBL1003,CLAVULANATE POTASSIUM,1.0
......@@ -147,9 +147,9 @@ def metrics(model):
pure_predictions = [item for sublist in [preds, predsN] for item in sublist]
labels = torch.tensor([item for sublist in [labels1, labels2] for item in sublist])
fpr, tpr, label1 = plot_roc(labels, pure_predictions, ('disorder', 'dis_dru_the', 'drug'), "dmsr-f/",
fpr, tpr, label1 = plot_roc(labels, pure_predictions, ('disorder', 'dis_dru_the', 'drug'), "dmsr/",
"RepoDB")
recall, precision, label2 = plot_prc(labels, pure_predictions, ('disorder', 'dis_dru_the', 'drug'), "dmsr-f/",
recall, precision, label2 = plot_prc(labels, pure_predictions, ('disorder', 'dis_dru_the', 'drug'), "dmsr/",
"RepoDB")
plotMetrics(fpr, tpr, label1, recall, precision, label2)
......
......@@ -159,9 +159,9 @@ def metrics(model):
pure_predictions = [item for sublist in [preds, predsN] for item in sublist]
labels = torch.tensor([item for sublist in [labels1, labels2] for item in sublist])
fpr, tpr, label1 = plot_roc(labels, pure_predictions, ('disorder', 'dis_dru_the', 'drug'), "dmsr-f/",
fpr, tpr, label1 = plot_roc(labels, pure_predictions, ('disorder', 'dis_dru_the', 'drug'), "dmsr/",
"RepoDB")
recall, precision, label2 = plot_prc(labels, pure_predictions, ('disorder', 'dis_dru_the', 'drug'), "dmsr-f/",
recall, precision, label2 = plot_prc(labels, pure_predictions, ('disorder', 'dis_dru_the', 'drug'), "dmsr/",
"RepoDB")
plotMetrics(fpr, tpr, label1, recall, precision, label2)
......
......@@ -7,8 +7,8 @@ from deepsnap.dataset import GraphDataset
from datetime import datetime
from deepsnap.hetero_gnn import HeteroSAGEConv
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # It defines whether to execute on cpu or gpu.
device = 'cpu'
constructor = heterograph_construction.DISNETConstructor(device=device) # Graph constructor.
edge = ('disorder', 'dis_dru_the', 'drug') # Graph edge type to study.
n = 200000 # Number of new predictions.
......@@ -19,8 +19,10 @@ Input:
original: Graph.
pred: Edge predictions.
"""
def filterPreds(original, pred):
headsO = original.edge_index[edge][0, :].long() # Heads of the original edges of the graph.
headsO = original.edge_index[edge][0, :].long() # Heads of the original edges of the graph.
new = []
for i, elem in enumerate(pred):
......@@ -28,11 +30,11 @@ def filterPreds(original, pred):
head = i
tail = torch.arange(0, len(pred_labels)) # All tails.
indexH = ((headsO == head).nonzero(as_tuple=True)[0]) # Index of those heads originally present in the graph.
#Check
print(len(tail))
for index in indexH:
tail = tail[tail != index] # Just get those tails not present in the original graph.
print(len(tail))
# Check
for t in original.edge_index[edge][1, indexH]:
tail = tail[tail != t] # Just get those tails not present in the original graph.
new.append([head, tail, pred_labels[tail].cpu().detach().numpy()]) # New predictions are appended.
return new
......@@ -47,24 +49,29 @@ Input:
Output:
Dataframe containing the top n predictions ordered and decoded.
"""
def getTopN(model, dataloader, n):
    """Predict all edges, drop those already in the graph and decode the top n.

    Input:
        model: Model exposing predict_all(batch).
        dataloader: Iterable yielding the graph batches to predict on.
        n: Number of new predictions to keep.
    Output:
        The result of constructor.decodePredictions for the filtered
        predictions (described in this file as a dataframe containing the
        top n predictions ordered and decoded).
    Raises:
        ValueError: If the dataloader yields no batch at all.
    """
    print(" Looking for new edges.")
    new = None
    # Iterate the dataloader directly: wrapping it in zip() only produced
    # 1-tuples, and calling .to() on such a tuple raises AttributeError.
    for batch in dataloader:
        # Return value is discarded, as in the original code; this assumes
        # .to() moves the batch in place - TODO confirm against deepsnap.
        batch.to(device)
        pred = model.predict_all(batch)  # Predict all edges.
        new = filterPreds(batch, pred)  # Filter those edges present in the original graph.
    if new is None:
        # Without this guard an empty dataloader surfaced as an unbound-name error.
        raise ValueError("dataloader yielded no batches; nothing to predict on")
    print(" Decoding predictions, this may take a while.")
    # When several batches are yielded, only the last one is decoded (unchanged behaviour).
    return constructor.decodePredictions(new, edge[1], n)
"""
It gets the heterograph object and its conversion to a dataloader.
Output:
The dataloader and the heterograph, in that order.
"""
def getOriginal():
hetero, _ = constructor.DISNETHeterograph()
hetero, _ = constructor.DISNETHeterograph(full=True, withoutRepoDB=False)
dataset = GraphDataset(
[hetero],
task='link_pred',
......@@ -76,17 +83,20 @@ def getOriginal():
)
return dataset_loader, hetero
"""
It wraps all the necessary calls to get the top n predictions of the DMSR model.
"""
def dmsr():
# Necessary instantiations.
original, hetero = getOriginal()
convs = generate_convs_link_pred_layers(hetero, HeteroSAGEConv, 107)
model = HeteroGNN(convs, hetero, 107, 0.8).to(device).to(device)
convs = generate_convs_link_pred_layers(hetero, HeteroSAGEConv, 31)
model = HeteroGNN(convs, hetero, 31, 0.5).to(device)
# Load and prepare for inference model.
model.load_state_dict(torch.load("./models/dmsr", map_location=torch.device(device)))
model.load_state_dict(torch.load("./models/dmsrC", map_location=torch.device(device)))
model = model.to(device)
model.eval()
......@@ -95,6 +105,7 @@ def dmsr():
_ = getTopN(model, original, n)
print("Finished getting top", n, "at", datetime.now().strftime("%H:%M:%S"))
# Script entry point: run the DMSR top-n pipeline only when executed directly.
if __name__ == "__main__":
    dmsr()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment