# topN.py
# Committed by ADRIAN AYUSO MUNOZ.
import torch
from deepSnapPred import HeteroGNN, generate_convs_link_pred_layers
from deepsnap.batch import Batch
from torch.utils.data import DataLoader
import argparse
import heterograph_construction
from deepsnap.dataset import GraphDataset
from datetime import datetime
from deepsnap.hetero_gnn import HeteroSAGEConv


def arg_parse():
    """Parse the command-line options of the script.

    Options:
        --device: device string. Defaults to 'cuda' when
            ``torch.cuda.is_available()`` is true, otherwise 'cpu'.
        --n: number of predictions (int). Defaults to 6.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``, carrying
        the ``device`` and ``n`` attributes.
    """
    fallback_device = 'cuda' if torch.cuda.is_available() else 'cpu'

    cli = argparse.ArgumentParser(description='Link pred arguments.')
    cli.add_argument('--device', type=str, default=fallback_device,
                     help='CPU / GPU device.')
    cli.add_argument('--n', type=int, default=6,
                     help='Number of predictions.')
    return cli.parse_args()

# Module-level state shared by the functions below.
# NOTE: the command line is parsed when this module is imported, not only
# when it is run as a script.
args = arg_parse()
# Project helper used below to build the heterograph and to decode predictions.
constructor = heterograph_construction.DISNETConstructor(device=args.device)
# Relation name passed to constructor.decodePredictions.
toStudy = 'dis_dru_the'
n = args.n       # Number of new predictions to report (top n).


def filterPreds(original, pred, key):
    """Drop predictions for edges that already exist in the original graph.

    Args:
        original: Graph object exposing ``edge_index[key]`` as a 2 x E
            tensor. Row 0 is read as head node indices and row 1 as tail
            node indices.
        pred: Predictions indexable as ``pred[key, head]``; each entry holds
            one score per candidate tail. One head is processed per element
            obtained by iterating ``pred``.
        key: Edge type selecting both the edges and the predictions.

    Returns:
        A list with one ``[head, tail, scores]`` entry per head, where
        ``tail`` is a tensor of the candidate tail indices that are NOT
        already linked to ``head`` in ``original`` and ``scores`` is a NumPy
        array with the matching predicted values.
    """
    edges = original.edge_index[key]
    heads_o = edges[0, :].long()
    tails_o = edges[1, :].long()

    new = []
    for head, _ in enumerate(pred):
        pred_labels = pred[key, head]
        num_tails = len(pred_labels)

        # Tail node ids already connected to this head. The positions where
        # the head occurs in the edge list are only used to look these ids
        # up; they are edge offsets, not node ids, and must not be compared
        # against the candidate tails directly.
        known = tails_o[heads_o == head].cpu()
        # Ignore ids outside the candidate range so the mask write below
        # can never go out of bounds.
        known = known[(known >= 0) & (known < num_tails)]

        # One boolean mask instead of re-filtering the tensor per known edge.
        keep = torch.ones(num_tails, dtype=torch.bool)
        keep[known] = False
        tail = torch.arange(0, num_tails)[keep]

        new.append([head, tail, pred_labels[tail].detach().cpu().numpy()])

    return new


def getTopNDS(model, original, dataloader, key, n):
    """Predict all edges of type ``key`` and decode the top ``n`` new ones.

    Args:
        model: Model exposing ``predict_all(batch)``.
        original: Iterable of graphs used as the reference for filtering out
            edges that already exist; iterated in lockstep with
            ``dataloader``.
        dataloader: Iterable of batches; each one is moved to ``args.device``
            before prediction.
        key: Edge type whose predictions are filtered and decoded.
        n: Number of predictions forwarded to
            ``constructor.decodePredictions``.

    Returns:
        Whatever ``constructor.decodePredictions`` returns for the filtered
        predictions.

    Raises:
        ValueError: If ``dataloader`` / ``original`` yield no pairs, so there
            is nothing to decode.

    Note:
        Only the predictions of the last (batch, graph) pair are decoded;
        earlier pairs are overwritten, as in the original implementation.
    """
    print("     Looking for new edges.")
    new = None
    # Inference only: no gradients are needed to score the candidate edges.
    with torch.no_grad():
        # A distinct loop name keeps the ``original`` argument from being
        # rebound while it is still being iterated.
        for batch, original_graph in zip(dataloader, original):
            batch.to(args.device)
            pred = model.predict_all(batch)

            new = filterPreds(original_graph, pred, key)

    if new is None:
        raise ValueError(
            "No batches to predict on: dataloader or original is empty."
        )

    print("     Decoding predictions, this may take a while.")
    return constructor.decodePredictions(new, toStudy, n)

def getOriginal():
    """Build the full DISNET heterograph and a loader over it.

    The heterograph comes from
    ``constructor.DISNETHeterographDeepSnap(full=True)`` and is wrapped in a
    single-graph ``GraphDataset`` configured for link prediction.

    Returns:
        A ``(loader, graph)`` tuple: a ``DataLoader`` with batch size 1 over
        the dataset, and the heterograph itself.
    """
    graph = constructor.DISNETHeterographDeepSnap(full=True)

    link_pred_dataset = GraphDataset(
        [graph],
        task='link_pred',
        edge_train_mode='disjoint',
        edge_message_ratio=0.8,
    )

    collate = Batch.collate()
    loader = DataLoader(link_pred_dataset, collate_fn=collate, batch_size=1)
    return loader, graph


def deepSnap():
    """Load the trained model and compute the top ``n`` new predictions.

    Builds the heterograph, instantiates ``HeteroGNN`` with ``HeteroSAGEConv``
    layers, loads the weights from ``./models/modelDeepSnapPred`` and runs
    ``getTopNDS`` for the ('disorder', 'dis_dru_the', 'drug') edge type,
    printing start and finish timestamps.

    Returns:
        The value returned by ``getTopNDS`` (previously computed and
        discarded).
    """
    hidden_size = 32  # Used for both the convolution layers and the model.

    original, hetero = getOriginal()
    conv1, conv2 = generate_convs_link_pred_layers(
        hetero, HeteroSAGEConv, hidden_size
    )
    model = HeteroGNN(conv1, conv2, hetero, hidden_size).to(args.device)

    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(
        "./models/modelDeepSnapPred",
        map_location=torch.device(args.device),
    )
    model.load_state_dict(state)
    model.eval()

    edge = ('disorder', 'dis_dru_the', 'drug')
    # A second, independent loader is built for inference: getTopNDS moves
    # its batches to the device, while the first loader's graphs serve as
    # the reference for filtering.
    # NOTE(review): presumably a single loader cannot serve both roles --
    # confirm before merging the two getOriginal() calls.
    toInfer, _ = getOriginal()

    print("Started getting top", n, "at", datetime.now().strftime("%H:%M:%S"))
    topN = getTopNDS(model, original, toInfer, edge, n)
    print("Finished getting top", n, "at", datetime.now().strftime("%H:%M:%S"))
    return topN

# Run the prediction pipeline only when this file is executed as a script.
if __name__ == '__main__':
    deepSnap()