# testRepoDB.py
# Committed by ADRIAN AYUSO MUNOZ.
import numpy as np
import torch
from deepSnapPred import HeteroGNN, generate_convs_link_pred_layers, main
from deepsnap.batch import Batch
from torch.utils.data import DataLoader
import argparse
import heterograph_construction
from deepsnap.dataset import GraphDataset
from datetime import datetime
from deepsnap.hetero_gnn import HeteroSAGEConv
from utilities import plot_roc, plot_prc, plot_dist
import matplotlib.pyplot as plt


def arg_parse():
    """Parse the command-line options of the script.

    The only option is ``--device``. When it is not given, the device falls
    back to ``'cuda'`` if ``torch.cuda.is_available()`` is true and to
    ``'cpu'`` otherwise.

    Returns:
        argparse.Namespace: the parsed options, exposing a ``device`` string.
    """
    fallback_device = 'cuda' if torch.cuda.is_available() else 'cpu'

    cli = argparse.ArgumentParser(description='Link pred arguments.')
    cli.add_argument(
        '--device',
        type=str,
        default=fallback_device,
        help='CPU / GPU device.',
    )
    return cli.parse_args()


# Module-level state shared by the functions below.
# NOTE: these statements run at import time, so importing this module parses
# sys.argv and instantiates the DISNET constructor.
args = arg_parse()
# Object used to build the heterograph and to decode predictions.
constructor = heterograph_construction.DISNETConstructor(device=args.device)
# Edge-type key handed to constructor.decodePredictions.
toStudy = 'dis_dru_the'

def getTopNDS(model, eid, dataloader, random=''):
    """Score the candidate edges in ``eid`` and decode the predictions.

    Every batch yielded by ``dataloader`` is scored against the same ``eid``;
    only the scores of the last batch are kept.

    Args:
        model: object exposing ``pred(batch, eid)``.
        eid: pair ``(first, second)`` of index containers; position ``i`` of
            each one identifies candidate edge ``i``.
            NOTE(review): assumed to be tensors/arrays supporting
            ``.tolist()`` -- confirm against the constructor's output.
        dataloader: iterable of batches supporting ``batch.to(device)``.
        random: forwarded unchanged as the last argument of
            ``constructor.decodePredictions``.

    Returns:
        tuple: ``(decoded, scores)`` where ``decoded`` is the value returned
        by ``constructor.decodePredictions`` and ``scores`` is the detached
        prediction tensor on the CPU.

    Raises:
        ValueError: if ``dataloader`` yields no batch.
    """
    print("     Looking for new edges.")
    preds = None
    # Pure inference: no autograd graph is needed for the scores.
    with torch.no_grad():
        for batch in dataloader:
            batch.to(args.device)
            preds = model.pred(batch, eid)

    if preds is None:
        raise ValueError("getTopNDS received an empty dataloader; nothing to score.")

    # Move everything to the host once instead of once per element.
    if not isinstance(preds, torch.Tensor):
        preds = torch.tensor(preds)
    preds = preds.detach().cpu()
    first_ids = eid[0].tolist()
    second_ids = eid[1].tolist()

    # One [first index, second index, score] triple per prediction.
    new = [
        [first_ids[i], second_ids[i], score.item()]
        for i, score in enumerate(preds)
    ]

    n = len(preds)
    print("     Decoding predictions, this may take a while.")
    return constructor.decodePredictions(new, toStudy, n, True, random), preds


def randomEids(n_edges=5013, src_high=30729, dst_high=3944, device=None):
    """Draw random index pairs to be scored as negative candidate edges.

    The pairs are sampled independently and uniformly; they are NOT checked
    against existing edges, so a sampled pair may coincide with a real one.

    Args:
        n_edges: number of pairs to draw.
        src_high: exclusive upper bound for the first index of each pair.
        dst_high: exclusive upper bound for the second index of each pair.
        device: device on which the tensors are created; ``None`` (default)
            uses the module-level ``args.device``.

    Returns:
        tuple: ``(first, second)``, two 1-D integer tensors of length
        ``n_edges``.
    """
    if device is None:
        device = args.device
    target = torch.device(device)
    first = torch.randint(0, src_high, (n_edges,), device=target)
    second = torch.randint(0, dst_high, (n_edges,), device=target)
    return (first, second)


def _draw_roc_prc(roc_ax, prc_ax, fpr, tpr, roc_label, recall, precision, prc_label):
    """Draw the ROC curve on ``roc_ax`` and the precision-recall curve on ``prc_ax``."""
    roc_ax.plot(fpr, tpr, label=roc_label)
    roc_ax.set_title('ROC Curve')
    roc_ax.legend(loc='lower right')
    # Diagonal of a random classifier, drawn after the legend so it gets no entry.
    roc_ax.plot([0, 1], [0, 1], 'r--')
    roc_ax.set_xlim([0, 1])
    roc_ax.set_ylim([0, 1])
    roc_ax.set_ylabel('True Positive Rate')
    roc_ax.set_xlabel('False Positive Rate')

    prc_ax.set_title('Precision-Recall Curve')
    prc_ax.plot(recall, precision, label=prc_label)
    prc_ax.legend(loc='lower right')
    prc_ax.set_xlim([0, 1])
    prc_ax.set_ylim([0, 1])
    prc_ax.set_ylabel('Precision')
    prc_ax.set_xlabel('Recall')


def plotMetrics(extension, fpr, tpr, label1, recall, precision, label2):
    """Save the ROC and precision-recall curves in two layouts.

    The same pair of curves is drawn stacked vertically and side by side,
    and each figure is written as an SVG under ``plots/<extension>/metrics/``.
    The output directory is created when missing. Finally
    ``plot_dist(extension + '/')`` is called.

    Args:
        extension: sub-directory of ``plots/`` that receives the figures.
        fpr, tpr: x / y values of the ROC curve.
        label1: value shown in the ROC legend, formatted with
            ``np.array2string`` to two decimals.
        recall, precision: x / y values of the precision-recall curve.
        label2: value shown in the precision-recall legend, formatted like
            ``label1``.
    """
    import os

    two_decimals = {'float_kind': lambda x: "%.2f" % x}
    roc_label = "AUC ROC = " + np.array2string(label1, formatter=two_decimals)
    prc_label = "PRC = " + np.array2string(label2, formatter=two_decimals)

    # Vertical layout: ROC on top, precision-recall below.
    fig, axs = plt.subplots(2, figsize=(6, 10))
    _draw_roc_prc(axs[0], axs[1], fpr, tpr, roc_label, recall, precision, prc_label)

    # Horizontal layout: ROC on the left, precision-recall on the right.
    fig2, axs2 = plt.subplots(1, 2, figsize=(12, 4))
    _draw_roc_prc(axs2[0], axs2[1], fpr, tpr, roc_label, recall, precision, prc_label)

    # savefig does not create missing directories, so make sure it exists.
    out_dir = 'plots/' + extension + '/metrics/'
    os.makedirs(out_dir, exist_ok=True)

    fig.savefig(out_dir + 'aucroc&prcRepoDBVertical.svg', format='svg', dpi=1200)
    fig2.savefig(out_dir + 'aucroc&prcRepoDBHorizontal.svg', format='svg', dpi=1200)
    plot_dist(extension + '/')


def deepSnapMetrics():
    """Evaluate the saved DeepSNAP link-prediction model and plot its metrics.

    Steps:
      1. Build the heterograph via
         ``constructor.DISNETHeterographDeepSnap(all=True, withoutRepoDB=True)``,
         which also returns the edge ids to score.
      2. Load the weights stored at ``./models/modelDeepSnapPred``.
      3. Score those edge ids (labelled 1) and the pairs returned by
         ``randomEids()`` (labelled 0).
      4. Compute the ROC / precision-recall curves with ``plot_roc`` and
         ``plot_prc`` and save the combined figures with ``plotMetrics``.
    """
    hetero, eids = constructor.DISNETHeterographDeepSnap(all=True, withoutRepoDB=True)
    dataset = GraphDataset(
        [hetero],
        task='link_pred',
        edge_train_mode='disjoint',
        edge_message_ratio=0.8
    )
    toInfer = DataLoader(
        dataset, collate_fn=Batch.collate(), batch_size=1
    )

    # NOTE(review): presumably has to match the hidden size the saved model
    # was trained with -- confirm against deepSnapPred.
    hidden_dim = 45
    convs = generate_convs_link_pred_layers(hetero, HeteroSAGEConv, hidden_dim)

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model = HeteroGNN(convs, hetero, hidden_dim).to(args.device)
    model.load_state_dict(torch.load("./models/modelDeepSnapPred", map_location=torch.device(args.device)))
    model.eval()

    print("Started getting repoDB predictions at", datetime.now().strftime("%H:%M:%S"))
    _, preds = getTopNDS(model, eids, toInfer)
    print("Finished getting repoDB predictions at", datetime.now().strftime("%H:%M:%S"))

    # Random pairs act as the negative class.
    _, predsN = getTopNDS(model, randomEids(), toInfer, 'R')

    # Positives first, negatives second, in the same order as the scores.
    labels = torch.cat((torch.ones(len(preds)), torch.zeros(len(predsN))))
    pure_predictions = [*preds, *predsN]

    edge_type = ('disorder', 'dis_dru_the', 'drug')
    fpr, tpr, label1 = plot_roc(labels, pure_predictions, edge_type, "deepSnapPred/", "RepoDB")
    recall, precision, label2 = plot_prc(labels, pure_predictions, edge_type, "deepSnapPred/", "RepoDB")

    plotMetrics("deepSnapPred", fpr, tpr, label1, recall, precision, label2)
        


# Script entry point: run deepSnapPred's main() first, then the evaluation.
# NOTE(review): main() presumably produces the model file that
# deepSnapMetrics() loads from ./models/modelDeepSnapPred -- confirm.
if __name__ == '__main__':
    main()
    deepSnapMetrics()