# topN.py
import torch
from deepSnapPred import HeteroGNN, generate_convs_link_pred_layers
from deepsnap.batch import Batch
from torch.utils.data import DataLoader
import heterograph_construction
from deepsnap.dataset import GraphDataset
from datetime import datetime
from deepsnap.hetero_gnn import HeteroSAGEConv


# Run on the GPU when one is available; otherwise fall back to the CPU so the
# script still works on machines without CUDA (the checkpoint is loaded with
# map_location=device, so either choice is supported).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
constructor = heterograph_construction.DISNETConstructor(device=device)
# Relation name handed to constructor.decodePredictions in getTopNDS.
toStudy = 'dis_dru_the'
n = 1479774  # Number of new predictions to keep (top n).

# Hyperparameters used to build the HeteroGNN in deepSnap().
# NOTE(review): presumably these must match the saved checkpoint - confirm.
hidden_dim = 107
dropout = 0.8


def filterPreds(original, pred, key):
    """Remove already-known edges from the predictions of every head node.

    Args:
        original: Object exposing ``edge_index[key]`` as a ``2 x E`` tensor.
            Row 0 is read as the head node indices; row 1 is assumed to hold
            the matching tail node indices (usual COO edge-index layout).
        pred: Container indexed as ``pred[key, i]`` that yields the score
            vector of head ``i`` over all candidate tails; ``len(pred)``
            gives the number of heads to process.
        key: Edge-type key used for both ``original.edge_index`` and ``pred``.

    Returns:
        A list with one ``[head, tail, scores]`` entry per head, where
        ``tail`` is a CPU tensor with the candidate tail indices that are not
        connected to ``head`` in ``original`` and ``scores`` is a NumPy array
        with the predicted values for exactly those tails.
    """
    known_edges = original.edge_index[key].long()
    headsO = known_edges[0, :]
    tailsO = known_edges[1, :]

    new = []
    for head in range(len(pred)):
        pred_labels = pred[key, head]
        n_tails = len(pred_labels)

        # Tail *node ids* already linked to this head. (Filtering by the
        # position of the edge inside edge_index would drop unrelated tails
        # and keep the known ones.)
        known_tails = tailsO[headsO == head].cpu()
        # Ignore ids outside the prediction vector so masking cannot fail.
        known_tails = known_tails[known_tails < n_tails]

        keep = torch.ones(n_tails, dtype=torch.bool)
        keep[known_tails] = False
        tail = torch.arange(0, n_tails)[keep]

        new.append([head, tail, pred_labels[tail].cpu().detach().numpy()])

    return new


def getTopNDS(model, original, dataloader, key, n):
    """Predict all edges of type ``key`` and decode the top ``n`` new ones.

    Args:
        model: Model exposing ``predict_all(batch)``.
        original: Iterable of reference graphs/batches, consumed in lockstep
            with ``dataloader``; its known edges are removed from the
            predictions by ``filterPreds``.
        dataloader: Iterable of batches to run the model on.
        key: Edge-type key to extract from the predictions.
        n: Number of predictions to keep, forwarded to
            ``constructor.decodePredictions``.

    Returns:
        Whatever ``constructor.decodePredictions`` returns for the filtered
        predictions of the last processed batch.

    Raises:
        ValueError: If ``dataloader`` / ``original`` yield no batch at all.
    """
    print("     Looking for new edges.")
    new = None
    # Pure inference: no autograd graph is needed for the predictions.
    with torch.no_grad():
        # A distinct loop name avoids shadowing the ``original`` parameter.
        for batch, reference in zip(dataloader, original):
            batch.to(device)
            pred = model.predict_all(batch)

            # Only the result of the last batch is kept, as before.
            new = filterPreds(reference, pred, key)

    if new is None:
        raise ValueError("getTopNDS received no batches to predict on.")

    print("     Decoding predictions, this may take a while.")
    return constructor.decodePredictions(new, toStudy, n)

def getOriginal():
    """Build the full heterograph and a single-graph link-prediction loader.

    Returns:
        A ``(loader, graph)`` tuple: ``graph`` is the first value returned by
        ``constructor.DISNETHeterographDeepSnap(full=True)`` and ``loader`` is
        a ``DataLoader`` (batch size 1) over a ``GraphDataset`` that wraps
        that graph.
    """
    graph, _ = constructor.DISNETHeterographDeepSnap(full=True)

    link_pred_dataset = GraphDataset(
        [graph],
        task='link_pred',
        edge_train_mode='disjoint',
        edge_message_ratio=0.8,
    )

    collate = Batch.collate()
    loader = DataLoader(link_pred_dataset, collate_fn=collate, batch_size=1)
    return loader, graph


def deepSnap():
    """Load the trained model and compute the top ``n`` new predicted edges.

    Builds the heterograph, recreates the ``HeteroGNN`` with the module-level
    hyperparameters, loads the weights stored in ``./models/behor`` and runs
    ``getTopNDS`` for the ``('disorder', 'dis_dru_the', 'drug')`` edge type.

    Returns:
        The value returned by ``getTopNDS`` (previously computed but
        discarded).
    """
    original, hetero = getOriginal()
    convs = generate_convs_link_pred_layers(hetero, HeteroSAGEConv, hidden_dim)
    # One device transfer is enough; load_state_dict copies the weights into
    # the parameters that already live on ``device``.
    model = HeteroGNN(convs, hetero, hidden_dim, dropout).to(device)

    model.load_state_dict(torch.load("./models/behor", map_location=torch.device(device)))
    model.eval()

    edge = ('disorder', 'dis_dru_the', 'drug')
    # A second, freshly built loader is used for inference while ``original``
    # serves as the reference for the known edges.
    toInfer, _ = getOriginal()

    print("Started getting top", n, "at", datetime.now().strftime("%H:%M:%S"))
    topN = getTopNDS(model, original, toInfer, edge, n)
    print("Finished getting top",n, "at", datetime.now().strftime("%H:%M:%S"))
    return topN

# Run the top-n prediction pipeline only when executed as a script.
if __name__ == '__main__':
    deepSnap()