import numpy as np
import torch
from dmsr import main
from deepsnap.batch import Batch
from torch.utils.data import DataLoader
import heterograph_construction
from deepsnap.dataset import GraphDataset
from datetime import datetime
from utilities import plot_roc, plot_prc, plot_dist
import matplotlib.pyplot as plt
import scipy.stats as st

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Graph constructor; also used to decode predictions.
constructor = heterograph_construction.DISNETConstructor(device=device)
# Graph edge type to study.
toStudy = 'dis_dru_the'


def getDecode(model, eid, dataloader, random=''):
    """
    It gets the predictions of the model and decodes them.

    Input:
        model: Model to generate predictions (must expose `pred(batch, eid)`).
        eid: Edges to predict, as a pair of index tensors (first / second endpoint).
        dataloader: Graph, wrapped in a DataLoader. If it yields several batches,
            only the predictions of the last batch are kept (as in the original code).
        random: Flag forwarded to `constructor.decodePredictions` telling whether the
            edges to predict are random or not.
    Output:
        Tuple of (decoded predictions as returned by `constructor.decodePredictions`,
        raw prediction scores as a detached CPU tensor).
    Raises:
        ValueError: If the dataloader yields no batch at all.
    """
    print(" Looking for new edges.")
    preds = None
    for batch in dataloader:
        batch.to(device)
        # Pure inference: the scores are detached below anyway, so there is no
        # reason to build the autograd graph.
        with torch.no_grad():
            preds = model.pred(batch, eid)

    if preds is None:
        raise ValueError("dataloader yielded no batches; nothing to predict")

    # One row per edge: [first endpoint index, second endpoint index, score].
    new = [
        [eid[0][i].item(), eid[1][i].item(), pred.cpu().detach().numpy().item()]
        for i, pred in enumerate(preds)
    ]
    n = len(preds)

    print(" Decoding predictions, this may take a while.")
    return constructor.decodePredictions(new, toStudy, n, True, random), torch.tensor(preds).cpu().detach()


def randomEids(numEdges=5013, numSources=30729, numTargets=3944):
    """
    It generates random edges in the graph.

    Input:
        numEdges: Number of random edges to draw.
        numSources: Exclusive upper bound for the first endpoint index.
        numTargets: Exclusive upper bound for the second endpoint index.
        NOTE(review): the defaults are the constants that were hard-coded here;
        presumably they match the size of the real edge set and the node counts
        of the two endpoint types - confirm against the graph constructor.
    Output:
        Randomly generated edges, as a pair of index tensors on `device`.
    """
    tensor1 = torch.randint(0, numSources, (numEdges,), device=device)
    tensor2 = torch.randint(0, numTargets, (numEdges,), device=device)
    return (tensor1, tensor2)


def _drawMetricAxes(rocAx, prcAx, fpr, tpr, rocLabel, recall, precision, prcLabel):
    """Draw the ROC curve on `rocAx` and the Precision-Recall curve on `prcAx`."""
    rocAx.plot(fpr, tpr, label=rocLabel)
    rocAx.set_title('ROC Curve')
    rocAx.legend(loc='lower right')
    # Diagonal reference line (random classifier).
    rocAx.plot([0, 1], [0, 1], 'r--')
    rocAx.set_xlim([0, 1])
    rocAx.set_ylim([0, 1])
    rocAx.set_ylabel('True Positive Rate')
    rocAx.set_xlabel('False Positive Rate')

    prcAx.set_title('Precision-Recall Curve')
    prcAx.plot(recall, precision, label=prcLabel)
    prcAx.legend(loc='lower right')
    prcAx.set_xlim([0, 1])
    prcAx.set_ylim([0, 1])
    prcAx.set_ylabel('Precision')
    prcAx.set_xlabel('Recall')


def plotMetrics(fpr, tpr, label1, recall, precision, label2):
    """
    It plots the metrics for the results of the real edge and the random edge set.
    The ROC and PR curves are saved twice: stacked vertically and side by side.
    Afterwards `plot_dist()` is called.

    Input:
        fpr: False positive rate.
        tpr: True positive rate.
        label1: Area Under the ROC curve (formatted with `np.array2string`).
        recall: Recall.
        precision: Precision.
        label2: Area Under the PR curve (formatted with `np.array2string`).
    """
    twoDecimals = {'float_kind': lambda x: "%.2f" % x}
    rocLabel = "AUC ROC = " + np.array2string(label1, formatter=twoDecimals)
    prcLabel = "PRC = " + np.array2string(label2, formatter=twoDecimals)

    # Vertical plotting.
    fig, axs = plt.subplots(2, figsize=(6, 10))
    _drawMetricAxes(axs[0], axs[1], fpr, tpr, rocLabel, recall, precision, prcLabel)

    # Horizontal plotting.
    fig2, axs2 = plt.subplots(1, 2, figsize=(12, 4))
    _drawMetricAxes(axs2[0], axs2[1], fpr, tpr, rocLabel, recall, precision, prcLabel)

    fig.savefig('metrics/aucroc&prcRepoDBVertical.svg', format='svg', dpi=1200)
    fig2.savefig('metrics/aucroc&prcRepoDBHorizontal.svg', format='svg', dpi=1200)
    plt.close(fig)
    plt.close(fig2)
    plot_dist()


def metrics(model):
    """
    It generates the metrics for the model.

    The edges returned by the graph constructor are scored as positives and a set
    of randomly generated edges is scored as negatives; ROC and PR metrics are
    computed over the union and plotted.

    Input:
        model: Model to generate metrics of.
    Output:
        Area Under the ROC and PR curve.
    """
    hetero, eids = constructor.DISNETHeterograph(full=False, withoutRepoDB=True)
    dataset = GraphDataset(
        [hetero],
        task='link_pred',
        edge_train_mode='disjoint',
        edge_message_ratio=0.8
    )
    toInfer = DataLoader(
        dataset,
        collate_fn=Batch.collate(),
        batch_size=1
    )

    model = model.to(device)
    model.eval()

    print("Started getting repoDB predictions at", datetime.now().strftime("%H:%M:%S"))
    _, preds = getDecode(model, eids, toInfer)
    print("Finished getting repoDB predictions at", datetime.now().strftime("%H:%M:%S"))

    print("Started getting random predictions at", datetime.now().strftime("%H:%M:%S"))
    _, predsN = getDecode(model, randomEids(), toInfer, 'R')
    print("Finished getting random predictions at", datetime.now().strftime("%H:%M:%S"))

    # Real edges are the positive class, random edges the negative class.
    labels = torch.cat([torch.ones(len(preds)), torch.zeros(len(predsN))])
    # Join real and random edge results in one list to calculate metrics.
    pure_predictions = [*preds, *predsN]

    edgeType = ('disorder', 'dis_dru_the', 'drug')
    fpr, tpr, label1 = plot_roc(labels, pure_predictions, edgeType, "dmsr/", "RepoDB")
    recall, precision, label2 = plot_prc(labels, pure_predictions, edgeType, "dmsr/", "RepoDB")

    plotMetrics(fpr, tpr, label1, recall, precision, label2)

    return label1, label2


def _confidenceInterval(values, confidence=0.95):
    """
    Confidence interval for the mean of `values`.

    If the number of samples is under 30 the t-distribution is used, otherwise
    the normal distribution is used. Returns the (low, high) bounds.
    """
    n = len(values)
    mean = np.mean(values)
    sem = st.sem(values)
    if n < 30:
        return st.t.interval(confidence, n - 1, loc=mean, scale=sem)
    return st.norm.interval(confidence, loc=mean, scale=sem)


if __name__ == '__main__':
    # Number of iterations.
    k = 50

    # Set of hyperparameters.
    epochs = 2343
    hidden_dim = 31
    lr = 0.0010235455088934942
    weight_decay = 0.005144745056173074
    dropout = 0.5

    # Train and test k models and obtain their metrics.
    rocValues, prcValues = [], []
    for i in range(k):
        model = main(epochs, hidden_dim, lr, weight_decay, dropout)
        roc1, prc1 = metrics(model)
        rocValues.append(roc1)
        prcValues.append(prc1)

    # Metrics list for each model.
    rocL = np.array(rocValues, dtype=float).ravel()
    prcL = np.array(prcValues, dtype=float).ravel()

    # Average of the metrics of all the generated models.
    rocM = np.mean(rocL)
    prcM = np.mean(prcL)

    # Obtain confidence intervals; the interval is symmetric around the mean, so
    # the half-width is reported as the "+-" term.
    r = _confidenceInterval(rocL)
    p = _confidenceInterval(prcL)

    print("AUCROC: ", rocM, "+-", rocM - r[0])
    print("AUCPR: ", prcM, "+-", prcM - p[0])