import torch
from dmsr import HeteroGNN, generate_convs_link_pred_layers
from deepsnap.batch import Batch
from torch.utils.data import DataLoader
import heterograph_construction
from deepsnap.dataset import GraphDataset
from datetime import datetime
from deepsnap.hetero_gnn import HeteroSAGEConv

# It defines whether to execute on cpu or gpu.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): this unconditionally overrides the detection above and forces
# CPU execution. Remove this line to let the script run on a GPU.
device = 'cpu'

constructor = heterograph_construction.DISNETConstructor(device=device)  # Graph constructor.
edge = ('disorder', 'dis_dru_the', 'drug')  # Graph edge type to study.
n = 200000  # Number of new predictions.


def filterPreds(original, pred):
    """
    It filters the predictions so just new predictions are considered, i.e.
    (head, tail) pairs of type `edge` that are not already edges of `original`.

    Input:
        original: Graph. original.edge_index[edge] holds its existing edges,
            row 0 being the heads and row 1 the tails.
        pred: Edge predictions as returned by model.predict_all. One entry is
            read per element of `pred`, via pred[edge, i] for head i.
    Output:
        List with one [head, tails, scores] entry per head, where `tails` is an
        int64 tensor with the tail indices not linked to `head` in `original`
        (in ascending order) and `scores` is a NumPy array with the
        corresponding predicted values.
    """
    edge_index = original.edge_index[edge]
    heads_orig = edge_index[0, :].long()  # Heads of the original edges of the graph.
    tails_orig = edge_index[1, :].long()  # Tails of the original edges of the graph.
    new = []
    for head, _ in enumerate(pred):
        pred_labels = pred[edge, head]
        num_tails = len(pred_labels)
        # Tails already linked to this head in the original graph. Values
        # outside [0, num_tails) can never match a candidate tail, so they are
        # dropped before being used as indices.
        existing = tails_orig[heads_orig == head]
        existing = existing[(existing >= 0) & (existing < num_tails)]
        # Vectorised membership test: a single masked write instead of
        # re-filtering the whole tail range once per existing edge.
        keep = torch.ones(num_tails, dtype=torch.bool)
        keep[existing.cpu()] = False
        tail = torch.arange(0, num_tails)[keep]  # Just those tails not present in the original graph.
        new.append([head, tail, pred_labels[tail].cpu().detach().numpy()])  # New predictions are appended.
    return new


def getTopN(model, dataloader, n):
    """
    It gets the top n predictions of the model and decodes them.

    Input:
        model: Model to generate predictions.
        dataloader: Graph dataloader (see getOriginal). It is expected to yield
            a single batch; if it yields several, only the predictions of the
            last one are decoded.
        n: Number of predictions.
    Output:
        Dataframe containing the top n predictions ordered and decoded, as
        produced by constructor.decodePredictions.
    Raises:
        ValueError: if the dataloader yields no batches.
    """
    print(" Looking for new edges.")
    new = None
    for batch in dataloader:
        pred = model.predict_all(batch)  # Predict all edges.
        new = filterPreds(batch, pred)  # Filter those edges present in the original graph.
    if new is None:
        # Without this check an empty loader surfaced as an UnboundLocalError.
        raise ValueError("The dataloader yielded no batches; there is nothing to predict.")
    print(" Decoding predictions, this may take a while.")
    return constructor.decodePredictions(new, edge[1], n)


def getOriginal():
    """
    It gets the heterograph object and its conversion to dataloader.

    Output:
        A tuple (dataloader, heterograph): a DataLoader over a single-graph
        link-prediction GraphDataset (batch_size=1), and the heterograph itself.
    """
    hetero, _ = constructor.DISNETHeterograph(full=True, withoutRepoDB=False)
    dataset = GraphDataset(
        [hetero],
        task='link_pred',
        edge_train_mode='disjoint',
        edge_message_ratio=0.8,
    )
    dataset_loader = DataLoader(
        dataset,
        collate_fn=Batch.collate(),
        batch_size=1,
    )
    return dataset_loader, hetero


def dmsr():
    """
    It wraps all the necessary calls to get the top n predictions of the DMSR
    model: builds the graph, loads the trained weights and runs getTopN.
    """
    # Necessary instantiations.
    original, hetero = getOriginal()
    # NOTE(review): 31 and 0.5 are presumably layer size and dropout; they must
    # match the configuration "./models/dmsrC" was saved with, otherwise
    # load_state_dict fails on mismatched shapes.
    convs = generate_convs_link_pred_layers(hetero, HeteroSAGEConv, 31)
    model = HeteroGNN(convs, hetero, 31, 0.5).to(device)

    # Load and prepare for inference model. The model already lives on
    # `device` and load_state_dict copies into its existing parameters, so no
    # extra .to(device) is needed afterwards.
    # torch.load unpickles the checkpoint: only load trusted model files.
    model.load_state_dict(torch.load("./models/dmsrC", map_location=torch.device(device)))
    model.eval()

    # Get top n.
    print("Started getting top", n, "at", datetime.now().strftime("%H:%M:%S"))
    # Inference only: disable autograd so no computation graph is built and
    # kept alive for the full set of predictions.
    with torch.no_grad():
        _ = getTopN(model, original, n)
    print("Finished getting top", n, "at", datetime.now().strftime("%H:%M:%S"))


if __name__ == '__main__':
    dmsr()