testRepoDB.py 6.73 KB
Newer Older
ADRIAN  AYUSO MUNOZ's avatar
ADRIAN AYUSO MUNOZ committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194
import numpy as np
import torch
from dmsr import main
from deepsnap.batch import Batch
from torch.utils.data import DataLoader
import heterograph_construction
from deepsnap.dataset import GraphDataset
from datetime import datetime
from utilities import plot_roc, plot_prc, plot_dist
import matplotlib.pyplot as plt

import scipy.stats as st

# Execute on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Graph constructor: builds the DISNET heterograph and decodes predictions (see metrics/getDecode).
constructor = heterograph_construction.DISNETConstructor(device=device)
# Edge type under study; metrics() evaluates it as ('disorder', 'dis_dru_the', 'drug').
toStudy = 'dis_dru_the'


"""
It gets the predictions of the model and decodes them.
Input:
    model: Model to generate predictions.
    eid: Edges to predict.
    dataloader: Graph.
    random: Whether the edges to predict are random or not.
Output:
    Dataframe containing all the predictions ordered and decoded.
"""
def getDecode(model, eid, dataloader, random=''):
    print("     Looking for new edges.")
    for batch in dataloader:
        batch.to(device)
        preds = model.pred(batch, eid)
        new = []
        for i, pred in enumerate(preds):
            new.append([eid[0][i].item(), eid[1][i].item(), pred.cpu().detach().numpy().item()])

    n = len(preds)
    print("     Decoding predictions, this may take a while.")
    return constructor.decodePredictions(new, toStudy, n, True, random), torch.tensor(preds).cpu().detach()


"""
It generates random edges in the graph.
Output:
    Randomly generated edges.
"""
def randomEids():
    tensor1 = torch.randint(0, 30729, (5013,), device=torch.device(device))
    tensor2 = torch.randint(0, 3944, (5013,), device=torch.device(device))
    return (tensor1, tensor2)

"""
It plots the metrics for the results of the real edge and the random edge set. It joins them vertically and horizontally.
Input:
    fpr: False positve rate.
    tpr: True positive rate.
    label1: Area Under the ROC curve.
    recall: Recall.
    precision: Precision.
    label2: Area Under the PR curve.
"""
def plotMetrics(fpr, tpr, label1, recall, precision, label2):
    # Vertical plotting.
    fig, axs = plt.subplots(2, figsize=(6, 10))

    axs[0].plot(fpr, tpr, label="AUC ROC = " + np.array2string(label1, formatter={'float_kind': lambda x: "%.2f" % x}))
    axs[0].set_title('ROC Curve')
    axs[0].legend(loc='lower right')
    axs[0].plot([0, 1], [0, 1], 'r--')
    axs[0].set_xlim([0, 1])
    axs[0].set_ylim([0, 1])
    axs[0].set_ylabel('True Positive Rate')
    axs[0].set_xlabel('False Positive Rate')

    axs[1].set_title('Precision-Recall Curve')
    axs[1].plot(recall, precision,
                label="PRC = " + np.array2string(label2, formatter={'float_kind': lambda x: "%.2f" % x}))
    axs[1].legend(loc='lower right')
    axs[1].set_xlim([0, 1])
    axs[1].set_ylim([0, 1])
    axs[1].set_ylabel('Precision')
    axs[1].set_xlabel('Recall')

    # Horizontal plotting.
    fig2, axs2 = plt.subplots(1, 2, figsize=(12, 4))

    axs2[0].plot(fpr, tpr, label="AUC ROC = " + np.array2string(label1, formatter={'float_kind': lambda x: "%.2f" % x}))
    axs2[0].set_title('ROC Curve')
    axs2[0].legend(loc='lower right')
    axs2[0].plot([0, 1], [0, 1], 'r--')
    axs2[0].set_xlim([0, 1])
    axs2[0].set_ylim([0, 1])
    axs2[0].set_ylabel('True Positive Rate')
    axs2[0].set_xlabel('False Positive Rate')

    axs2[1].set_title('Precision-Recall Curve')
    axs2[1].plot(recall, precision,
                 label="PRC = " + np.array2string(label2, formatter={'float_kind': lambda x: "%.2f" % x}))
    axs2[1].legend(loc='lower right')
    axs2[1].set_xlim([0, 1])
    axs2[1].set_ylim([0, 1])
    axs2[1].set_ylabel('Precision')
    axs2[1].set_xlabel('Recall')

    fig.savefig('metrics/aucroc&prcRepoDBVertical.svg', format='svg', dpi=1200)
    fig2.savefig('metrics/aucroc&prcRepoDBHorizontal.svg', format='svg', dpi=1200)
    plt.close(fig)
    plt.close(fig2)
    plot_dist()


"""
It generates the metrics for the model.
Input:
    model: Model to generate metrics of.
Output:
    Area Under the ROC and PR curve.
"""
def metrics(model):
    hetero, eids = constructor.DISNETHeterograph(full=False, withoutRepoDB=True)
    dataset = GraphDataset(
        [hetero],
        task='link_pred',
        edge_train_mode='disjoint',
        edge_message_ratio=0.8
    )
    toInfer = DataLoader(
        dataset, collate_fn=Batch.collate(), batch_size=1
    )

    model = model.to(device)
    model.eval()

    print("Started getting repoDB predictions at", datetime.now().strftime("%H:%M:%S"))
    _, preds = getDecode(model, eids, toInfer)
    print("Finished getting repoDB predictions at", datetime.now().strftime("%H:%M:%S"))

    print("Started getting random predictions at", datetime.now().strftime("%H:%M:%S"))
    _, predsN = getDecode(model, randomEids(), toInfer, 'R')
    print("Finished getting random predictions at", datetime.now().strftime("%H:%M:%S"))

    labels1 = torch.ones(len(preds))
    labels2 = torch.zeros(len(predsN))

    # Join real and random edge results in one list to calculate metrics.
    pure_predictions = [item for sublist in [preds, predsN] for item in sublist]
    labels = torch.tensor([item for sublist in [labels1, labels2] for item in sublist])

    fpr, tpr, label1 = plot_roc(labels, pure_predictions, ('disorder', 'dis_dru_the', 'drug'), "dmsr-f/",
                                "RepoDB")
    recall, precision, label2 = plot_prc(labels, pure_predictions, ('disorder', 'dis_dru_the', 'drug'), "dmsr-f/",
                                         "RepoDB")

    plotMetrics(fpr, tpr, label1, recall, precision, label2)
    return label1, label2


if __name__ == '__main__':
    # How many models are trained from scratch and evaluated.
    k = 50

    # Hyperparameters shared by every run, in main()'s positional order:
    # epochs, hidden_dim, lr, weight_decay, dropout.
    hyperparams = (2752, 107, 0.0008317, 0.006314, 0.8)

    # Per-model AUROC / AUPRC, one entry appended per trained model.
    roc_scores, prc_scores = np.array([]), np.array([])
    for _ in range(k):
        model = main(*hyperparams)
        roc_run, prc_run = metrics(model)
        roc_scores = np.append(roc_scores, roc_run)
        prc_scores = np.append(prc_scores, prc_run)

    # Average metric over all generated models.
    rocM = sum(roc_scores) / k
    prcM = sum(prc_scores) / k

    def _confidence(values):
        """95% interval around the mean: Student's t below 30 samples, normal distribution otherwise."""
        centre, spread = np.mean(values), st.sem(values)
        if k < 30:
            return st.t.interval(0.95, k - 1, loc=centre, scale=spread)
        return st.norm.interval(0.95, loc=centre, scale=spread)

    r = _confidence(roc_scores)
    p = _confidence(prc_scores)

    # Reported margin: distance from the mean down to the interval's lower bound.
    print("AUCROC:  ", rocM, "+-", rocM - r[0])
    print("AUCPR:   ", prcM, "+-", prcM - p[0])