import numpy as np
import torch
from dmsr import main
from deepsnap.batch import Batch
from torch.utils.data import DataLoader
import heterograph_construction
from deepsnap.dataset import GraphDataset
from datetime import datetime
from utilities import plot_roc, plot_prc, plot_dist
import matplotlib.pyplot as plt

import scipy.stats as st

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Shared graph constructor: builds the DISNET heterograph and decodes predictions.
constructor = heterograph_construction.DISNETConstructor(device=device)
# Edge type whose predictions are evaluated (disorder -> drug in `metrics`).
toStudy = 'dis_dru_the'

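# getDecode: get the model's predictions for the given edges and decode them.
#     model: Model used to generate the predictions.
#     eid: Edges to predict, as a (source indices, target indices) pair.
#     dataloader: Loader yielding the graph to predict on.
#     random: Marks whether the edges are random: '' (default) for the real edges, 'R' for random ones.
#     Returns: a DataFrame with all the predictions ordered and decoded, plus the raw prediction scores.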
def getDecode(model, eid, dataloader, random=''):
    """Score the edges ``eid`` with ``model`` and decode the predictions.

    Args:
        model: Model exposing ``pred(batch, eid)``, used to score the edges.
        eid: Edges to predict, as a ``(source_indices, target_indices)`` pair
            of index tensors (or anything ``torch.as_tensor`` accepts).
        dataloader: Iterable of graph batches. Every batch is scored against
            the same ``eid``, so only the last batch's predictions are kept
            (``metrics`` builds its loader with ``batch_size=1``).
        random: Flag forwarded to ``constructor.decodePredictions``: ``''`` for
            the real edge set, ``'R'`` for randomly generated edges.

    Returns:
        A pair ``(decoded, scores)``: whatever ``constructor.decodePredictions``
        returns for the ``[source, target, score]`` rows, and a detached CPU
        tensor with the raw prediction scores.

    Raises:
        ValueError: If ``dataloader`` yields no batches, or if a prediction
            does not hold exactly one score.
        IndexError: If ``eid`` holds fewer edges than there are predictions.
    """
    print("     Looking for new edges.")
    preds = None
    for batch in dataloader:
        # Every batch scores the same edges, so the last batch's output wins.
        preds = model.pred(batch, eid)
    if preds is None:
        raise ValueError("dataloader yielded no batches; nothing to predict")

    n = len(preds)
    # One device->host copy for all scores instead of one transfer per edge;
    # as_tensor avoids the copy-construct warning torch.tensor(tensor) emits.
    scores = torch.as_tensor(preds).detach().cpu()
    values = scores.reshape(-1).tolist()
    if len(values) != n:
        raise ValueError("expected exactly one score per predicted edge")
    sources = torch.as_tensor(eid[0]).reshape(-1).tolist()
    targets = torch.as_tensor(eid[1]).reshape(-1).tolist()
    # Indexing (not zip) keeps the IndexError when eid is shorter than preds.
    new = [[sources[i], targets[i], value] for i, value in enumerate(values)]

    print("     Decoding predictions, this may take a while.")
    return constructor.decodePredictions(new, toStudy, n, True, random), scores

# randomEids: generate random edges in the graph.
#     Returns: the randomly generated edges as a (source indices, target indices) tuple.
def randomEids(num_sources=30729, num_targets=3944, num_edges=5013, target_device=None):
    """Draw random candidate edges, used by ``metrics`` as the negative set.

    Indices are sampled uniformly with replacement and are not filtered, so a
    random edge may coincide with a real one.

    Args:
        num_sources: Exclusive upper bound for source node indices.
        num_targets: Exclusive upper bound for target node indices.
        num_edges: Number of random edges to draw.
        target_device: Device for the returned tensors; defaults to the
            module-level ``device``.

    The defaults are the values previously hard-coded here.
    NOTE(review): presumably the disorder/drug node counts and the size of
    the evaluated edge set — confirm against the graph constructor.

    Returns:
        Tuple ``(sources, targets)`` of int64 tensors of length ``num_edges``.
    """
    if target_device is None:
        target_device = device
    target_device = torch.device(target_device)
    # Same two randint calls, in the same order, as before: seeded runs reproduce.
    sources = torch.randint(0, num_sources, (num_edges,), device=target_device)
    targets = torch.randint(0, num_targets, (num_edges,), device=target_device)
    return (sources, targets)

# plotMetrics: plot the ROC and precision-recall curves for the results on the real and random edge sets,
# saving both a vertical and a horizontal layout of the two plots.
#     fpr: False positive rate.
#     tpr: True positive rate.
#     label1: Area under the ROC curve.
#     recall: Recall.
#     precision: Precision.
#     label2: Area under the PR curve.
def plotMetrics(fpr, tpr, label1, recall, precision, label2):
    """Plot the ROC and precision-recall curves and save them as SVG files.

    The same two panels are drawn twice, once stacked vertically and once
    side by side. They are saved to
    ``metrics/aucroc&prcRepoDBVertical.svg`` and
    ``metrics/aucroc&prcRepoDBHorizontal.svg``; the ``metrics`` directory is
    created if needed.

    Args:
        fpr: False positive rates of the ROC curve.
        tpr: True positive rates of the ROC curve.
        label1: Area under the ROC curve, shown in the legend.
        recall: Recall values of the PR curve.
        precision: Precision values of the PR curve.
        label2: Area under the PR curve, shown in the legend.
    """
    import os

    # Vertical plotting.
    fig, axs = plt.subplots(2, figsize=(6, 10))
    _draw_roc(axs[0], fpr, tpr, label1)
    _draw_prc(axs[1], recall, precision, label2)

    # Horizontal plotting.
    fig2, axs2 = plt.subplots(1, 2, figsize=(12, 4))
    _draw_roc(axs2[0], fpr, tpr, label1)
    _draw_prc(axs2[1], recall, precision, label2)

    os.makedirs('metrics', exist_ok=True)
    try:
        fig.savefig('metrics/aucroc&prcRepoDBVertical.svg', format='svg', dpi=1200)
        fig2.savefig('metrics/aucroc&prcRepoDBHorizontal.svg', format='svg', dpi=1200)
    finally:
        # This runs once per trained model; without closing, every call leaks
        # two open figures (matplotlib warns past 20 and memory keeps growing).
        plt.close(fig)
        plt.close(fig2)


def _format_auc(value):
    """Format an area-under-curve value with two decimals for a legend label."""
    # asarray lets plain Python floats through as well as numpy scalars/arrays.
    return np.array2string(np.asarray(value), formatter={'float_kind': lambda x: "%.2f" % x})


def _draw_roc(ax, fpr, tpr, auc):
    """Draw the ROC curve, with the chance diagonal, on ``ax``."""
    ax.plot(fpr, tpr, label="AUC ROC = " + _format_auc(auc))
    ax.set_title('ROC Curve')
    ax.legend(loc='lower right')
    # Added after legend() so the unlabelled diagonal stays out of the legend.
    ax.plot([0, 1], [0, 1], 'r--')
    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1])
    ax.set_ylabel('True Positive Rate')
    ax.set_xlabel('False Positive Rate')


def _draw_prc(ax, recall, precision, auc):
    """Draw the precision-recall curve on ``ax``."""
    ax.set_title('Precision-Recall Curve')
    ax.plot(recall, precision, label="PRC = " + _format_auc(auc))
    ax.legend(loc='lower right')
    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1])
    ax.set_ylabel('Precision')
    ax.set_xlabel('Recall')

# metrics: generate the evaluation metrics for the model.
#     model: Model to generate the metrics for.
#     Returns: the areas under the ROC and PR curves.
def metrics(model):
    """Evaluate ``model`` on a real edge set against random edges.

    Builds the graph with ``constructor.DISNETHeterograph(full=False,
    withoutRepoDB=True)``. It then scores two edge sets:

    - the returned ``eids``, labelled 1 (presumably the RepoDB edges withheld
      from the graph — TODO confirm);
    - the edges from ``randomEids()``, labelled 0.

    The ROC and precision-recall curves are computed with ``plot_roc`` /
    ``plot_prc`` and drawn with ``plotMetrics``.

    Args:
        model: Trained model. It is moved to ``device`` and must expose
            ``pred(batch, eid)`` (called through ``getDecode``).

    Returns:
        ``(label1, label2)`` as returned by ``plot_roc`` / ``plot_prc``.
        These are presumably the AUC-ROC and AUC-PR values (the ``__main__``
        block averages them) — TODO confirm against ``utilities``.
    """
    hetero, eids = constructor.DISNETHeterograph(full=False, withoutRepoDB=True)
    # NOTE(review): the arguments and closing parenthesis of GraphDataset( and
    # the closing parenthesis of DataLoader( are missing from the reviewed
    # source; restore them before running.
    dataset = GraphDataset(
    toInfer = DataLoader(
        dataset, collate_fn=Batch.collate(), batch_size=1

    model = model.to(device)

    print("Started getting repoDB predictions at", datetime.now().strftime("%H:%M:%S"))
    # Only the raw scores are used; the decoded predictions are discarded.
    _, preds = getDecode(model, eids, toInfer)
    print("Finished getting repoDB predictions at", datetime.now().strftime("%H:%M:%S"))

    print("Started getting random predictions at", datetime.now().strftime("%H:%M:%S"))
    _, predsN = getDecode(model, randomEids(), toInfer, 'R')
    print("Finished getting random predictions at", datetime.now().strftime("%H:%M:%S"))

    # Real edges are the positive class (1), random edges the negative class (0).
    labels1 = torch.ones(len(preds))
    labels2 = torch.zeros(len(predsN))

    # Join real and random edge results in one list to calculate metrics.
    pure_predictions = [item for sublist in [preds, predsN] for item in sublist]
    labels = torch.tensor([item for sublist in [labels1, labels2] for item in sublist])

    # NOTE(review): the remaining arguments and closing parentheses of the two
    # calls below are missing from the reviewed source. The edge-type tuple
    # repeats the module-level toStudy value.
    fpr, tpr, label1 = plot_roc(labels, pure_predictions, ('disorder', 'dis_dru_the', 'drug'), "dmsr-f/",
    recall, precision, label2 = plot_prc(labels, pure_predictions, ('disorder', 'dis_dru_the', 'drug'), "dmsr-f/",

    plotMetrics(fpr, tpr, label1, recall, precision, label2)
    return label1, label2

def confidence_interval(samples, confidence=0.95, small_sample_size=30):
    """Return the ``(lower, upper)`` confidence interval for the mean of ``samples``.

    Student's t-distribution with ``len(samples) - 1`` degrees of freedom is
    used when there are fewer than ``small_sample_size`` samples. Otherwise
    the normal approximation is used. With fewer than two samples the
    standard error is undefined and the bounds come out as NaN.

    Args:
        samples: 1-D sequence of metric values (flattened if not 1-D).
        confidence: Confidence level of the interval.
        small_sample_size: Sample count from which the normal distribution is used.

    Returns:
        Tuple ``(lower, upper)`` of interval bounds.
    """
    samples = np.asarray(samples, dtype=float).ravel()
    n = len(samples)
    mean = np.mean(samples)
    sem = st.sem(samples)
    if n < small_sample_size:
        return st.t.interval(confidence, n - 1, loc=mean, scale=sem)
    return st.norm.interval(confidence, loc=mean, scale=sem)


if __name__ == '__main__':
    # Metrics list for each model.
    rocL, prcL = np.array([]), np.array([])
    # Number of iterations (models trained and evaluated).
    k = 50
    # Set of hyperparameters.
    epochs = 2752
    hidden_dim = 107
    lr = 0.0008317
    weight_decay = 0.006314
    dropout = 0.8
    # Train and test k models and obtain their metrics.
    for _ in range(k):
        model = main(epochs, hidden_dim, lr, weight_decay, dropout)
        roc1, prc1 = metrics(model)
        rocL = np.append(rocL, roc1)
        prcL = np.append(prcL, prc1)
    # Average of the metrics of all the generated models. np.mean matches the
    # interval centre, so "mean - lower bound" below is the exact half-width.
    rocM = np.mean(rocL)
    prcM = np.mean(prcL)
    # 95% confidence intervals: t-distribution under 30 samples, normal otherwise.
    r = confidence_interval(rocL)
    p = confidence_interval(prcL)

    print("AUCROC:  ", rocM, "+-", rocM-r[0])
    print("AUCPR:   ", prcM, "+-", prcM-p[0])