"""Score a trained DeepSnap link-prediction model and plot ROC / PRC curves.

Running this file as a script first calls ``main`` from ``deepSnapPred`` and
then ``deepSnapMetrics``, which scores the edges returned by
``DISNETHeterographDeepSnap`` (labelled positive) against uniformly random
edges (labelled negative).
"""
import argparse
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import torch
from deepsnap.batch import Batch
from deepsnap.dataset import GraphDataset
from deepsnap.hetero_gnn import HeteroSAGEConv
from torch.utils.data import DataLoader

import heterograph_construction
from deepSnapPred import HeteroGNN, generate_convs_link_pred_layers, main
from utilities import plot_roc, plot_prc, plot_dist


def arg_parse():
    """Parse the command-line options.

    Returns:
        The parsed namespace; ``device`` defaults to ``'cuda'`` when CUDA is
        available and to ``'cpu'`` otherwise.
    """
    parser = argparse.ArgumentParser(description='Link pred arguments.')
    parser.add_argument('--device', type=str, help='CPU / GPU device.')
    parser.set_defaults(
        device='cuda' if torch.cuda.is_available() else 'cpu',
    )
    return parser.parse_args()


# Module-level state shared by the functions below. It is built at import
# time, so importing this module parses ``sys.argv``.
args = arg_parse()
constructor = heterograph_construction.DISNETConstructor(device=args.device)
# Edge type handed to ``constructor.decodePredictions``.
toStudy = 'dis_dru_the'


def getTopNDS(model, eid, dataloader, random=''):
    """Score the candidate edges ``eid`` and decode them via the constructor.

    Args:
        model: Object exposing ``pred(batch, eid)``.
        eid: Pair ``(sources, targets)`` of index sequences whose elements
            support ``.item()``.
        dataloader: Iterable of batches. Only the predictions computed on the
            last batch are kept (the loader built in ``deepSnapMetrics``
            holds a single graph, so that is the only batch).
        random: Forwarded unchanged as the last argument of
            ``constructor.decodePredictions``.

    Returns:
        A tuple ``(decoded, preds)`` where ``decoded`` is whatever
        ``constructor.decodePredictions`` returns and ``preds`` is a detached
        CPU tensor of the raw scores.

    Raises:
        ValueError: If ``dataloader`` yields no batch.
    """
    print(" Looking for new edges.")
    preds = None
    # Pure inference: skip building the autograd graph.
    with torch.no_grad():
        for batch in dataloader:
            batch.to(args.device)
            preds = model.pred(batch, eid)
    if preds is None:
        raise ValueError("dataloader yielded no batches; nothing to score")

    # as_tensor avoids the copy-construct warning of torch.tensor(tensor);
    # moving to the CPU once avoids one device transfer per prediction.
    preds = torch.as_tensor(preds).detach().cpu()
    new = [
        [eid[0][i].item(), eid[1][i].item(), score.item()]
        for i, score in enumerate(preds)
    ]
    n = len(preds)
    print(" Decoding predictions, this may take a while.")
    decoded = constructor.decodePredictions(new, toStudy, n, True, random)
    return decoded, preds


def randomEids(num_sources=30729, num_targets=3944, num_edges=5013):
    """Draw uniformly random candidate edges on the configured device.

    Args:
        num_sources: Exclusive upper bound for the first index of each edge.
        num_targets: Exclusive upper bound for the second index of each edge.
        num_edges: Number of edges to draw.

    The defaults presumably mirror the node and edge counts of the evaluated
    graph -- NOTE(review): confirm against the constructor.

    Returns:
        A tuple ``(sources, targets)`` of integer tensors of length
        ``num_edges``.
    """
    device = torch.device(args.device)
    sources = torch.randint(0, num_sources, (num_edges,), device=device)
    targets = torch.randint(0, num_targets, (num_edges,), device=device)
    return (sources, targets)


def _two_decimals(value):
    """Format ``value`` with ``np.array2string`` using two decimals."""
    return np.array2string(
        value, formatter={'float_kind': lambda x: "%.2f" % x})


def _draw_metric_panels(roc_ax, prc_ax, fpr, tpr, label1,
                        recall, precision, label2):
    """Draw the ROC curve on ``roc_ax`` and the PR curve on ``prc_ax``."""
    roc_ax.plot(fpr, tpr, label="AUC ROC = " + _two_decimals(label1))
    roc_ax.set_title('ROC Curve')
    roc_ax.legend(loc='lower right')
    # The diagonal is drawn after the legend is created, so it is not listed.
    roc_ax.plot([0, 1], [0, 1], 'r--')
    roc_ax.set_xlim([0, 1])
    roc_ax.set_ylim([0, 1])
    roc_ax.set_ylabel('True Positive Rate')
    roc_ax.set_xlabel('False Positive Rate')

    prc_ax.set_title('Precision-Recall Curve')
    prc_ax.plot(recall, precision, label="PRC = " + _two_decimals(label2))
    prc_ax.legend(loc='lower right')
    prc_ax.set_xlim([0, 1])
    prc_ax.set_ylim([0, 1])
    prc_ax.set_ylabel('Precision')
    prc_ax.set_xlabel('Recall')


def plotMetrics(extension, fpr, tpr, label1, recall, precision, label2):
    """Save the ROC and precision-recall curves in two layouts.

    One SVG stacks the two panels vertically and another places them side by
    side; both are written under ``plots/<extension>/metrics/``. Afterwards
    ``plot_dist(extension + '/')`` is called.

    Args:
        extension: Sub-directory of ``plots/`` to write into.
        fpr: X values of the ROC curve.
        tpr: Y values of the ROC curve.
        label1: Value shown in the ROC legend (formatted with
            ``np.array2string``).
        recall: X values of the precision-recall curve.
        precision: Y values of the precision-recall curve.
        label2: Value shown in the precision-recall legend.
    """
    fig, axs = plt.subplots(2, figsize=(6, 10))
    _draw_metric_panels(axs[0], axs[1], fpr, tpr, label1,
                        recall, precision, label2)

    fig2, axs2 = plt.subplots(1, 2, figsize=(12, 4))
    _draw_metric_panels(axs2[0], axs2[1], fpr, tpr, label1,
                        recall, precision, label2)

    fig.savefig('plots/' + extension + '/metrics/aucroc&prcRepoDBVertical.svg',
                format='svg', dpi=1200)
    fig2.savefig('plots/' + extension + '/metrics/aucroc&prcRepoDBHorizontal.svg',
                 format='svg', dpi=1200)
    # pyplot keeps figures alive until they are closed explicitly.
    plt.close(fig)
    plt.close(fig2)

    plot_dist(extension + '/')


def deepSnapMetrics():
    """Score positive and random edges with the saved model and plot metrics.

    Loads the state dict stored at ``./models/modelDeepSnapPred``, scores the
    edges returned by the constructor (label 1) and the edges from
    ``randomEids`` (label 0), then passes the combined scores to
    ``plot_roc``, ``plot_prc`` and ``plotMetrics``.
    """
    hetero, eids = constructor.DISNETHeterographDeepSnap(
        all=True, withoutRepoDB=True)

    dataset = GraphDataset(
        [hetero],
        task='link_pred',
        edge_train_mode='disjoint',
        edge_message_ratio=0.8
    )
    toInfer = DataLoader(
        dataset, collate_fn=Batch.collate(), batch_size=1
    )

    # Presumably has to match the width used when the checkpoint was
    # trained -- NOTE(review): confirm against deepSnapPred.
    hidden_dim = 45
    convs = generate_convs_link_pred_layers(hetero, HeteroSAGEConv, hidden_dim)
    model = HeteroGNN(convs, hetero, hidden_dim).to(args.device)
    model.load_state_dict(
        torch.load("./models/modelDeepSnapPred",
                   map_location=torch.device(args.device)))
    model.eval()

    print("Started getting repoDB predictions at",
          datetime.now().strftime("%H:%M:%S"))
    # The decoded results are not used here; the calls are kept as-is because
    # decodePredictions may have effects of its own.
    _, preds = getTopNDS(model, eids, toInfer)
    print("Finished getting repoDB predictions at",
          datetime.now().strftime("%H:%M:%S"))
    _, predsN = getTopNDS(model, randomEids(), toInfer, 'R')

    # Positives first, random edges second -- same order in both sequences.
    pure_predictions = [*preds, *predsN]
    labels = torch.cat([torch.ones(len(preds)), torch.zeros(len(predsN))])

    edge_type = ('disorder', 'dis_dru_the', 'drug')
    fpr, tpr, label1 = plot_roc(
        labels, pure_predictions, edge_type, "deepSnapPred/", "RepoDB")
    recall, precision, label2 = plot_prc(
        labels, pure_predictions, edge_type, "deepSnapPred/", "RepoDB")
    plotMetrics("deepSnapPred", fpr, tpr, label1, recall, precision, label2)


if __name__ == '__main__':
    main()
    deepSnapMetrics()