"""Load a trained DeepSnap HeteroGNN and extract its top new disorder-drug edges.

Rebuilds the DISNET heterograph, loads the weights stored at ``./models/behor``,
asks the model for predictions on the ``('disorder', 'dis_dru_the', 'drug')``
relation, drops candidates that already appear in the graph's ``edge_index``
and hands the rest to ``DISNETConstructor.decodePredictions`` to keep the top
``n``.
"""
import torch
from deepSnapPred import HeteroGNN, generate_convs_link_pred_layers
from deepsnap.batch import Batch
from torch.utils.data import DataLoader
import heterograph_construction
from deepsnap.dataset import GraphDataset
from datetime import datetime
from deepsnap.hetero_gnn import HeteroSAGEConv

# Fall back to CPU instead of failing on the first .to('cuda') when no GPU exists.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
constructor = heterograph_construction.DISNETConstructor(device=device)

toStudy = 'dis_dru_the'  # relation name forwarded to decodePredictions
n = 1479774  # Number of new predictions, top n.
# Presumably must match the configuration the checkpoint was trained with;
# a hidden_dim mismatch would make load_state_dict fail on tensor shapes.
hidden_dim = 107
dropout = 0.8


def filterPreds(original, pred, key):
    """Remove predicted edges that already exist in ``original``.

    Args:
        original: graph/batch whose ``edge_index[key]`` holds the known edges of
            relation ``key``; presumably a 2 x E tensor with heads in row 0 and
            tails in row 1 (the usual DeepSNAP layout) -- confirm.
        pred: output of ``model.predict_all``; ``pred[key, i]`` is read as the
            vector of scores of head ``i`` against every candidate tail index.
        key: edge-type triple, e.g. ``('disorder', 'dis_dru_the', 'drug')``.

    Returns:
        A list of ``[head, tails, scores]`` entries, one per head index, where
        ``tails`` is a LongTensor of tail indices NOT linked to ``head`` in
        ``original`` and ``scores`` is a numpy array of their predicted scores.
    """
    known_edges = original.edge_index[key]
    heads_o = known_edges[0, :].long()
    tails_o = known_edges[1, :].long()
    # NOTE(review): in a link_pred dataset ``edge_index`` may hold only the
    # message-passing edges; supervision edges (``edge_label_index``) would
    # then not be filtered out. Confirm which edges should count as "known".
    new = []
    # NOTE(review): iterates ``pred`` itself but indexes ``pred[key, i]``;
    # confirm predict_all's return type supports both (kept as originally written).
    for head, _ in enumerate(pred):
        pred_labels = pred[key, head]
        tail = torch.arange(0, len(pred_labels))
        # Tail ids already connected to this head in the original graph.
        # (The previous version removed the *edge positions* instead of the tail ids.)
        known_tails = tails_o[heads_o == head].to(tail.device)
        # Vectorized set difference; broadcasting works on any torch version.
        already_linked = (tail.unsqueeze(1) == known_tails.unsqueeze(0)).any(dim=1)
        tail = tail[~already_linked]
        new.append([head, tail, pred_labels[tail].cpu().detach().numpy()])
    return new


def getTopNDS(model, original, dataloader, key, n):
    """Run ``model`` over ``dataloader`` and decode the top-``n`` unseen edges.

    Args:
        model: trained model exposing ``predict_all(batch)``.
        original: iterable of graphs whose existing edges are filtered out,
            consumed in lockstep with ``dataloader``.
        dataloader: iterable of batches to predict on.
        key: edge-type triple whose predictions are kept.
        n: number of top predictions to decode.

    Returns:
        Whatever ``constructor.decodePredictions`` returns for the filtered
        predictions of the last batch.

    Raises:
        ValueError: if ``dataloader``/``original`` yield no batches.
    """
    print(" Looking for new edges.")
    new = None
    # Pure inference: no_grad avoids building an autograd graph over all scores.
    with torch.no_grad():
        for batch, original_graph in zip(dataloader, original):
            batch.to(device)
            pred = model.predict_all(batch)
            # Only the last batch is kept, as before; the loaders built by
            # getOriginal() hold a single graph (batch_size=1).
            new = filterPreds(original_graph, pred, key)
    if new is None:
        raise ValueError("getTopNDS: the dataloader yielded no batches")
    print(" Decoding predictions, this may take a while.")
    return constructor.decodePredictions(new, toStudy, n)


def getOriginal():
    """Build the full DISNET heterograph and wrap it in a link-prediction loader.

    Returns:
        ``(dataset_loader, hetero)``: a DataLoader over a one-graph
        ``GraphDataset`` (batch_size=1) and the underlying heterograph.
    """
    hetero, _ = constructor.DISNETHeterographDeepSnap(full=True)
    dataset = GraphDataset(
        [hetero],
        task='link_pred',
        edge_train_mode='disjoint',
        edge_message_ratio=0.8,
    )
    dataset_loader = DataLoader(dataset, collate_fn=Batch.collate(), batch_size=1)
    return dataset_loader, hetero


def deepSnap():
    """Load the saved model, compute the top-``n`` new edges and return them.

    Returns:
        The decoded top-``n`` predictions produced by ``getTopNDS`` (previously
        computed and discarded).
    """
    original, hetero = getOriginal()
    convs = generate_convs_link_pred_layers(hetero, HeteroSAGEConv, hidden_dim)
    model = HeteroGNN(convs, hetero, hidden_dim, dropout).to(device)
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    model.load_state_dict(torch.load("./models/behor", map_location=torch.device(device)))
    model.eval()
    edge = ('disorder', toStudy, 'drug')
    toInfer, _ = getOriginal()
    print("Started getting top", n, "at", datetime.now().strftime("%H:%M:%S"))
    topN = getTopNDS(model, original, toInfer, edge, n)
    print("Finished getting top", n, "at", datetime.now().strftime("%H:%M:%S"))
    return topN


if __name__ == '__main__':
    deepSnap()