"""Evaluate a trained link-prediction model and plot ROC / precision-recall curves.

The script trains ``k`` models through ``deepSnapPred.main``, scores a set of
held-out edges against randomly drawn index pairs, and reports the mean
AUC-ROC / AUC-PR together with a 95% Student-t confidence half-width.
"""

import os
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as st
import torch
from deepsnap.batch import Batch
from deepsnap.dataset import GraphDataset
from torch.utils.data import DataLoader

import heterograph_construction
from deepSnapPred import main
from utilities import plot_roc, plot_prc, plot_dist

device = 'cuda'
constructor = heterograph_construction.DISNETConstructor(device=device)
toStudy = 'dis_dru_the'


def getTopNDS(model, eid, dataloader, random=''):
    """Score the edges in ``eid`` with ``model`` and decode the predictions.

    Args:
        model: object exposing ``pred(batch, eid)``.
        eid: pair ``(first_indices, second_indices)``; element ``i`` of each
            is read with ``.item()``.
        dataloader: iterable of batches. Every batch is scored, but only the
            predictions of the LAST batch are used (as in the original code;
            the caller in this file builds a single-graph loader).
        random: forwarded unchanged as the last argument of
            ``constructor.decodePredictions``.

    Returns:
        Tuple ``(decoded, scores)`` where ``decoded`` is whatever
        ``constructor.decodePredictions`` returns and ``scores`` is a
        detached CPU tensor holding the predictions.

    Raises:
        ValueError: if ``dataloader`` yields no batch at all.
    """
    print(" Looking for new edges.")
    preds = None
    # Pure inference: no autograd graph is needed, so do not build one.
    with torch.no_grad():
        for batch in dataloader:
            batch.to(device)
            preds = model.pred(batch, eid)
    if preds is None:
        raise ValueError("dataloader yielded no batches; nothing to score")

    # Move everything to the CPU once instead of once per prediction.
    # as_tensor avoids the copy-construct warning when preds is already a tensor.
    scores = torch.as_tensor(preds).detach().cpu()
    new = [
        [eid[0][i].item(), eid[1][i].item(), score.item()]
        for i, score in enumerate(scores)
    ]
    n = len(scores)
    print(" Decoding predictions, this may take a while.")
    return constructor.decodePredictions(new, toStudy, n, True, random), scores


def randomEids(first_upper=30729, second_upper=3944, count=5013):
    """Draw ``count`` uniformly random index pairs on ``device``.

    Args:
        first_upper: exclusive upper bound for the first index of each pair.
        second_upper: exclusive upper bound for the second index of each pair.
        count: number of pairs to draw.

    Returns:
        Tuple ``(first_indices, second_indices)`` of integer tensors.

    NOTE(review): the defaults presumably match the node counts of the two
    node types and the number of evaluated edges -- confirm against the graph.
    The pairs are not filtered against existing edges, so a drawn pair can
    coincide with a true edge.
    """
    target = torch.device(device)
    first = torch.randint(0, first_upper, (count,), device=target)
    second = torch.randint(0, second_upper, (count,), device=target)
    return (first, second)


def _format_score(value):
    """Render ``value`` with two decimals (expects a NumPy scalar or array)."""
    return np.array2string(value, formatter={'float_kind': lambda x: "%.2f" % x})


def _draw_metric_axes(roc_ax, prc_ax, fpr, tpr, roc_label, recall, precision, prc_label):
    """Draw the ROC curve on ``roc_ax`` and the precision-recall curve on ``prc_ax``."""
    roc_ax.plot(fpr, tpr, label=roc_label)
    roc_ax.set_title('ROC Curve')
    roc_ax.legend(loc='lower right')
    # Chance diagonal, drawn after the legend call as in the original.
    roc_ax.plot([0, 1], [0, 1], 'r--')
    roc_ax.set_xlim([0, 1])
    roc_ax.set_ylim([0, 1])
    roc_ax.set_ylabel('True Positive Rate')
    roc_ax.set_xlabel('False Positive Rate')

    prc_ax.set_title('Precision-Recall Curve')
    prc_ax.plot(recall, precision, label=prc_label)
    prc_ax.legend(loc='lower right')
    prc_ax.set_xlim([0, 1])
    prc_ax.set_ylim([0, 1])
    prc_ax.set_ylabel('Precision')
    prc_ax.set_xlabel('Recall')


def plotMetrics(extension, fpr, tpr, label1, recall, precision, label2):
    """Save the ROC and precision-recall curves in a vertical and a horizontal layout.

    Both SVG files are written under ``plots/<extension>/metrics/`` (the
    directory is created if missing), then ``plot_dist(extension + '/')`` is
    called. The figures are closed afterwards so repeated calls do not
    accumulate open figures.

    Args:
        extension: sub-directory of ``plots/`` to write into.
        fpr, tpr: x / y values of the ROC curve.
        label1: value shown in the ROC legend (formatted via ``np.array2string``).
        recall, precision: x / y values of the precision-recall curve.
        label2: value shown in the PRC legend (formatted via ``np.array2string``).
    """
    roc_label = "AUC ROC = " + _format_score(label1)
    prc_label = "PRC = " + _format_score(label2)
    out_dir = 'plots/' + extension + '/metrics'

    fig, axs = plt.subplots(2, figsize=(6, 10))
    fig2, axs2 = plt.subplots(1, 2, figsize=(12, 4))
    try:
        _draw_metric_axes(axs[0], axs[1], fpr, tpr, roc_label, recall, precision, prc_label)
        _draw_metric_axes(axs2[0], axs2[1], fpr, tpr, roc_label, recall, precision, prc_label)

        os.makedirs(out_dir, exist_ok=True)
        fig.savefig(out_dir + '/aucroc&prcRepoDBVertical.svg', format='svg', dpi=1200)
        fig2.savefig(out_dir + '/aucroc&prcRepoDBHorizontal.svg', format='svg', dpi=1200)
        plot_dist(extension + '/')
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)
        plt.close(fig2)


def deepSnapMetrics(model, extension="behor"):
    """Score held-out edges against random pairs and plot the resulting metrics.

    The edges returned by ``constructor.DISNETHeterographDeepSnap`` are
    labelled 1 and the pairs from ``randomEids()`` are labelled 0; both sets
    are scored with ``model`` and handed to ``plot_roc`` / ``plot_prc``.

    Args:
        model: model to evaluate; moved to ``device`` and put in eval mode.
        extension: plot sub-directory name passed to the plotting helpers.

    Returns:
        Tuple ``(label1, label2)``: the third value returned by ``plot_roc``
        and by ``plot_prc`` respectively (presumably AUC-ROC and AUC-PR).
    """
    hetero, eids = constructor.DISNETHeterographDeepSnap(full=False, withoutRepoDB=True)
    dataset = GraphDataset(
        [hetero],
        task='link_pred',
        edge_train_mode='disjoint',
        edge_message_ratio=0.8,
    )
    toInfer = DataLoader(dataset, collate_fn=Batch.collate(), batch_size=1)

    model = model.to(device)
    model.eval()

    print("Started getting repoDB predictions at", datetime.now().strftime("%H:%M:%S"))
    _, preds = getTopNDS(model, eids, toInfer)
    print("Finished getting repoDB predictions at", datetime.now().strftime("%H:%M:%S"))
    _, predsN = getTopNDS(model, randomEids(), toInfer, 'R')

    # Positives first, then negatives; labels follow the same order.
    pure_predictions = [*preds, *predsN]
    labels = torch.cat([torch.ones(len(preds)), torch.zeros(len(predsN))])

    edge_type = ('disorder', 'dis_dru_the', 'drug')
    fpr, tpr, label1 = plot_roc(labels, pure_predictions, edge_type, extension + "/", "RepoDB")
    recall, precision, label2 = plot_prc(labels, pure_predictions, edge_type, extension + "/", "RepoDB")
    plotMetrics(extension, fpr, tpr, label1, recall, precision, label2)
    return label1, label2


def _mean_with_half_width(values, confidence=0.95):
    """Return ``(mean, half_width)`` of the Student-t confidence interval of ``values``."""
    mean = np.mean(values)
    lower, _ = st.t.interval(confidence, len(values) - 1, loc=mean, scale=st.sem(values))
    return mean, mean - lower


if __name__ == '__main__':
    rocL, prcL = np.array([]), np.array([])
    k = 20
    epochs = 2752
    hidden_dim = 107
    lr = 0.0008317
    weight_decay = 0.006314
    dropout = 0.8

    # Train and evaluate k independent models, collecting both metrics.
    for _ in range(k):
        model = main(epochs, hidden_dim, lr, weight_decay, dropout)
        roc1, prc1 = deepSnapMetrics(model)
        rocL = np.append(rocL, roc1)
        prcL = np.append(prcL, prc1)

    # One mean serves as both the printed value and the interval centre.
    rocM, roc_half = _mean_with_half_width(rocL)
    prcM, prc_half = _mean_with_half_width(prcL)
    print("AUCROC: ", rocM, "+-", roc_half)
    print("AUCPR: ", prcM, "+-", prc_half)