# testRepoDB.py
# Trains DMSR models and reports AUC ROC / AUC PR for RepoDB edges against randomly generated edges.
import numpy as np
import torch
from dmsr import main
from deepsnap.batch import Batch
from torch.utils.data import DataLoader
import heterograph_construction
from deepsnap.dataset import GraphDataset
from datetime import datetime
from utilities import plot_roc, plot_prc, plot_dist
import matplotlib.pyplot as plt

import scipy.stats as st

# Execute on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Graph constructor, bound to the selected device. It is also used to decode predictions.
constructor = heterograph_construction.DISNETConstructor(device=device)

# Graph edge type to study.
toStudy = 'dis_dru_the'

"""
It gets the predictions of the model and decodes them.
Input:
    model: Model to generate predictions.
    eid: Edges to predict.
    dataloader: Graph.
    random: Whether the edges to predict are random or not.
Output:
    Dataframe containing all the predictions ordered and decoded.
"""
def getDecode(model, eid, dataloader, random=''):
    """
    It gets the predictions of the model for the given edges and decodes them.
    Input:
        model: Model to generate predictions; it must expose pred(batch, eid).
        eid: Edges to predict, as a pair (first endpoint indices, second endpoint indices).
        dataloader: Graph. When it yields several batches, only the predictions of the last one are kept.
        random: Forwarded unchanged to constructor.decodePredictions; callers pass 'R' for random edges.
    Output:
        Tuple with the value returned by constructor.decodePredictions and a CPU tensor holding a detached
        copy of the raw predictions.
    Raises:
        ValueError: If the dataloader yields no batch, so there is nothing to decode.
    """
    print("     Looking for new edges.")
    preds = None
    # Inference only: no autograd graph is needed because the predictions are detached right after.
    with torch.no_grad():
        for batch in dataloader:
            batch.to(device)
            preds = model.pred(batch, eid)

    if preds is None:
        raise ValueError("getDecode: the dataloader yielded no batches, there are no predictions to decode.")

    # Single detached CPU copy of the scores, reused for both the decoded rows and the returned tensor.
    scores = torch.as_tensor(preds).detach().clone().cpu()

    # One bulk conversion per column instead of one .item() call (and device sync) per element.
    new = [
        [src, dst, score]
        for src, dst, score in zip(eid[0].tolist(), eid[1].tolist(), scores.reshape(-1).tolist())
    ]

    n = len(preds)
    print("     Decoding predictions, this may take a while.")
    return constructor.decodePredictions(new, toStudy, n, True, random), scores


"""
It generates random edges in the graph.
Output:
    Randomly generated edges.
"""
def randomEids(numEdges=5013, numSources=30729, numTargets=3944, targetDevice=None):
    """
    It generates random edges in the graph.
    Input:
        numEdges: Number of random edges to generate.
        numSources: Exclusive upper bound for the first endpoint index (presumably the number of
            nodes on that side of the studied edge type; confirm against the graph constructor).
        numTargets: Exclusive upper bound for the second endpoint index (same caveat).
        targetDevice: Device on which the tensors are created; the module-level device when None.
    Output:
        Tuple (first endpoint indices, second endpoint indices), each a 1-D integer tensor of length numEdges.
    """
    if targetDevice is None:
        targetDevice = device
    # Endpoints are sampled independently and uniformly; duplicated edges are possible.
    sources = torch.randint(0, numSources, (numEdges,), device=targetDevice)
    targets = torch.randint(0, numTargets, (numEdges,), device=targetDevice)
    return (sources, targets)

# Plotting of the evaluation curves.
"""
It plots the metrics for the results of the real edge and the random edge set. It joins them vertically and horizontally.
Input:
    fpr: False positive rate.
    tpr: True positive rate.
    label1: Area Under the ROC curve.
    recall: Recall.
    precision: Precision.
    label2: Area Under the PR curve.
"""
def plotMetrics(fpr, tpr, label1, recall, precision, label2):
    """
    It draws the ROC and Precision-Recall curves twice, stacked vertically and side by side, and saves
    both figures as SVG files under the 'metrics' directory. It finally calls plot_dist().
    Input:
        fpr: False positive rate values (x axis of the ROC curve).
        tpr: True positive rate values (y axis of the ROC curve).
        label1: Area Under the ROC curve; passed to np.array2string, so a numpy value is expected.
        recall: Recall values (x axis of the PR curve).
        precision: Precision values (y axis of the PR curve).
        label2: Area Under the PR curve; passed to np.array2string, so a numpy value is expected.
    """
    import os

    twoDecimals = {'float_kind': lambda x: "%.2f" % x}
    rocLegend = "AUC ROC = " + np.array2string(label1, formatter=twoDecimals)
    prcLegend = "PRC = " + np.array2string(label2, formatter=twoDecimals)

    # Vertical plotting.
    fig, axs = plt.subplots(2, figsize=(6, 10))
    _drawRocAndPrc(axs[0], axs[1], fpr, tpr, rocLegend, recall, precision, prcLegend)

    # Horizontal plotting.
    fig2, axs2 = plt.subplots(1, 2, figsize=(12, 4))
    _drawRocAndPrc(axs2[0], axs2[1], fpr, tpr, rocLegend, recall, precision, prcLegend)

    # savefig does not create missing directories, so make sure the output folder exists.
    os.makedirs('metrics', exist_ok=True)
    fig.savefig('metrics/aucroc&prcRepoDBVertical.svg', format='svg', dpi=1200)
    fig2.savefig('metrics/aucroc&prcRepoDBHorizontal.svg', format='svg', dpi=1200)
    plt.close(fig)
    plt.close(fig2)
    plot_dist()


def _drawRocAndPrc(rocAx, prcAx, fpr, tpr, rocLegend, recall, precision, prcLegend):
    """It draws the ROC curve on rocAx and the Precision-Recall curve on prcAx, both on the unit square."""
    rocAx.plot(fpr, tpr, label=rocLegend)
    rocAx.set_title('ROC Curve')
    rocAx.legend(loc='lower right')
    # Dashed red diagonal of the unit square as a visual reference; it is unlabeled, so it stays out of the legend.
    rocAx.plot([0, 1], [0, 1], 'r--')
    rocAx.set_xlim([0, 1])
    rocAx.set_ylim([0, 1])
    rocAx.set_ylabel('True Positive Rate')
    rocAx.set_xlabel('False Positive Rate')

    prcAx.set_title('Precision-Recall Curve')
    prcAx.plot(recall, precision, label=prcLegend)
    prcAx.legend(loc='lower right')
    prcAx.set_xlim([0, 1])
    prcAx.set_ylim([0, 1])
    prcAx.set_ylabel('Precision')
    prcAx.set_xlabel('Recall')


"""
It generates the metrics for the model.
Input:
    model: Model to generate metrics of.
Output:
    Area Under the ROC and PR curve.
"""
def metrics(model):
    """
    It generates the metrics for the model: it scores the edges returned by the graph constructor
    (labelled 1) and a set of randomly generated edges (labelled 0), then computes and plots the
    ROC and Precision-Recall curves for the joint set.
    Input:
        model: Model to generate metrics of. It is moved to the module-level device and set to eval mode.
    Output:
        Area Under the ROC and PR curve, as returned by plot_roc and plot_prc.
    """
    # NOTE(review): eids are presumably the RepoDB edges left out of the graph (withoutRepoDB=True);
    # confirm against DISNETConstructor.
    hetero, eids = constructor.DISNETHeterograph(full=False, withoutRepoDB=True)
    dataset = GraphDataset(
        [hetero],
        task='link_pred',
        edge_train_mode='disjoint',
        edge_message_ratio=0.8
    )
    toInfer = DataLoader(
        dataset, collate_fn=Batch.collate(), batch_size=1
    )

    model = model.to(device)
    model.eval()

    print("Started getting repoDB predictions at", datetime.now().strftime("%H:%M:%S"))
    _, preds = getDecode(model, eids, toInfer)
    print("Finished getting repoDB predictions at", datetime.now().strftime("%H:%M:%S"))

    print("Started getting random predictions at", datetime.now().strftime("%H:%M:%S"))
    _, predsN = getDecode(model, randomEids(), toInfer, 'R')
    print("Finished getting random predictions at", datetime.now().strftime("%H:%M:%S"))

    # Join real and random edge results to calculate metrics: real edges are the positive class (1)
    # and random edges the negative class (0), in the same order as the predictions.
    pure_predictions = list(preds) + list(predsN)
    labels = torch.cat([torch.ones(len(preds)), torch.zeros(len(predsN))])

    edgeType = ('disorder', 'dis_dru_the', 'drug')
    fpr, tpr, label1 = plot_roc(labels, pure_predictions, edgeType, "dmsr/", "RepoDB")
    recall, precision, label2 = plot_prc(labels, pure_predictions, edgeType, "dmsr/", "RepoDB")

    plotMetrics(fpr, tpr, label1, recall, precision, label2)
    return label1, label2


if __name__ == '__main__':
    # Metrics list for each model.
    rocL, prcL = np.array([]), np.array([])
    # Number of iterations.
    k = 50

    # Set of hyperparameters.
    epochs = 2343
    hidden_dim = 31
    lr = 0.0010235455088934942
    weight_decay = 0.005144745056173074
    dropout = 0.5

    # Train and test k models and obtain their metrics.
    for _ in range(k):
        model = main(epochs, hidden_dim, lr, weight_decay, dropout)
        roc1, prc1 = metrics(model)
        rocL = np.append(rocL, roc1)
        prcL = np.append(prcL, prc1)

    # For each metric: average over all the generated models and 95% confidence interval. If the number of
    # samples is under 30 the t-distribution is used, otherwise the normal distribution is used.
    for title, samples in (("AUCROC:  ", rocL), ("AUCPR:   ", prcL)):
        # The same mean is used as the interval centre and as the printed value, so the half-width is exact.
        mean = np.mean(samples)
        stdErr = st.sem(samples)
        if k < 30:
            lower, _ = st.t.interval(0.95, k - 1, loc=mean, scale=stdErr)
        else:
            lower, _ = st.norm.interval(0.95, loc=mean, scale=stdErr)
        # The interval is symmetric around the mean, so mean - lower is its half-width.
        print(title, mean, "+-", mean - lower)