import argparse
from datetime import datetime

import torch
from torch.utils.data import DataLoader

import heterograph_construction
from deepSnapPred import HeteroGNN, generate_convs_link_pred_layers
from deepsnap.batch import Batch
from deepsnap.dataset import GraphDataset
from deepsnap.hetero_gnn import HeteroSAGEConv


def arg_parse():
    """Parse the command-line options for link-prediction inference.

    Returns:
        argparse.Namespace with ``device`` (defaults to ``'cuda'`` when CUDA
        is available, otherwise ``'cpu'``) and ``n`` (defaults to 6).
    """
    parser = argparse.ArgumentParser(description='Link pred arguments.')
    parser.add_argument('--device', type=str, help='CPU / GPU device.')
    parser.add_argument('--n', type=int, help='Number of predictions.')
    parser.set_defaults(
        device='cuda' if torch.cuda.is_available() else 'cpu',
        n=6,
    )
    return parser.parse_args()


args = arg_parse()
constructor = heterograph_construction.DISNETConstructor(device=args.device)
toStudy = 'dis_dru_the'
n = args.n  # Number of new predictions, top n.


def filterPreds(original, pred, key):
    """Drop edges already present in ``original`` from the per-head scores.

    Args:
        original: graph/batch whose ``edge_index[key]`` holds the known edges
            for ``key`` (row 0 = head node ids, row 1 = tail node ids).
        pred: prediction container indexed as ``pred[key, head]``.
        key: message type, e.g. ``('disorder', 'dis_dru_the', 'drug')``.

    Returns:
        A list with one ``[head, tail, scores]`` entry per head. ``tail`` is a
        CPU LongTensor of candidate tail ids that are NOT already linked to
        ``head``. ``scores`` is the numpy array ``pred_labels[tail]``.
    """
    edges = original.edge_index[key]
    headsO = edges[0, :].long()
    tailsO = edges[1, :].long()
    new = []
    # NOTE(review): `pred` is indexed as pred[key, head], so it is assumed to be
    # a mapping keyed by (message_type, head_index) whose iteration length is
    # the number of heads for `key` -- confirm against HeteroGNN.predict_all.
    for head, _ in enumerate(pred):
        pred_labels = pred[key, head]
        num_tails = len(pred_labels)
        # Tail node ids (not edge positions) already linked to this head.
        known_tails = tailsO[headsO == head].cpu()
        # Guard the mask assignment below against out-of-range ids.
        known_tails = known_tails[(known_tails >= 0) & (known_tails < num_tails)]
        keep = torch.ones(num_tails, dtype=torch.bool)
        keep[known_tails] = False
        tail = torch.arange(0, num_tails)[keep]
        new.append([head, tail, pred_labels[tail].cpu().detach().numpy()])
    return new


def getTopNDS(model, original, dataloader, key, n):
    """Score all candidate ``key`` links and decode the top ``n`` new ones.

    Args:
        model: trained model exposing ``predict_all(batch)``.
        original: iterable of batches holding the known edges, zipped in
            lockstep with ``dataloader``.
        dataloader: iterable of batches to run inference on.
        key: message type whose links are predicted.
        n: number of top predictions passed to ``decodePredictions``.

    Returns:
        Whatever ``constructor.decodePredictions`` returns.

    Raises:
        ValueError: if the zipped loaders yield no batches.
    """
    print(" Looking for new edges.")
    new = None
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for batch, original_batch in zip(dataloader, original):
            batch.to(args.device)
            pred = model.predict_all(batch)
            # Only the last batch's predictions are kept (the loaders built by
            # getOriginal use batch_size=1 over a single graph).
            new = filterPreds(original_batch, pred, key)
    if new is None:
        raise ValueError("No batches available to predict on.")
    print(" Decoding predictions, this may take a while.")
    return constructor.decodePredictions(new, toStudy, n)


def getOriginal():
    """Build the full DISNET heterograph and a batch_size=1 loader over it.

    Returns:
        ``(dataset_loader, hetero)``: a DataLoader over a link-prediction
        GraphDataset (``edge_train_mode='disjoint'``,
        ``edge_message_ratio=0.8``) and the underlying heterograph.
    """
    hetero = constructor.DISNETHeterographDeepSnap(full=True)
    dataset = GraphDataset(
        [hetero],
        task='link_pred',
        edge_train_mode='disjoint',
        edge_message_ratio=0.8,
    )
    dataset_loader = DataLoader(
        dataset,
        collate_fn=Batch.collate(),
        batch_size=1,
    )
    return dataset_loader, hetero


def deepSnap():
    """Load the trained model and compute the top-n new disorder-drug links.

    Returns:
        The decoded top-n predictions from ``getTopNDS``.
    """
    original, hetero = getOriginal()
    conv1, conv2 = generate_convs_link_pred_layers(hetero, HeteroSAGEConv, 32)
    model = HeteroGNN(conv1, conv2, hetero, 32).to(args.device)
    state = torch.load(
        "./models/modelDeepSnapPred",
        map_location=torch.device(args.device),
    )
    model.load_state_dict(state)
    model.eval()
    edge = ('disorder', toStudy, 'drug')
    # NOTE(review): getOriginal() is called a second time here; if the
    # disjoint message/supervision edge split is randomized, `original` and
    # `toInfer` may hold different edge sets -- confirm this is intended.
    toInfer, _ = getOriginal()
    print("Started getting top", n, "at", datetime.now().strftime("%H:%M:%S"))
    topN = getTopNDS(model, original, toInfer, edge, n)
    print("Finished getting top", n, "at", datetime.now().strftime("%H:%M:%S"))
    return topN


if __name__ == '__main__':
    deepSnap()