# -*- coding: utf-8 -*- """ Created on Wed Sep 29 10:44:43 2021 """ import copy from utilities import plot_roc, plot_prc import numpy as np import torch import torch.nn as nn from torch.utils.data import DataLoader from matplotlib import pyplot as plt import heterograph_construction from datetime import datetime from deepsnap.dataset import GraphDataset from deepsnap.batch import Batch from deepsnap.hetero_gnn import ( HeteroSAGEConv, HeteroConv, forward_op ) edges = [('disorder', 'dis_dru_the', 'drug')] # It sets the edges that will be studied. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # It defines whether to execute on cpu or gpu. # --------------------------- # FUNCTIONS # --------------------------- """ It instantiates the convolutional layers for the given graph and dimensions. A convolutional layer should be instantiated for every type of edge in the graph. Input: hete: Graph. conv: Type of convolution to be applied. hidden_size: Number of hidden dimensions of the convolutions. Output: Instantiated convolutional layers. """ def generate_convs_link_pred_layers(hete, conv, hidden_size): convs1 = {} convs2 = {} for message_type in hete.message_types: n_type = message_type[0] s_type = message_type[2] n_feat_dim = hete.num_node_features(n_type) s_feat_dim = hete.num_node_features(s_type) convs1[message_type] = conv(n_feat_dim, hidden_size, s_feat_dim) convs2[message_type] = conv(hidden_size, hidden_size, hidden_size) return [convs1, convs2] """ Heterogeneous Graph Neural Network. """ class HeteroGNN(torch.nn.Module): """ It instantiates the HeteroGNN and all its layers (convolutions, leaky ReLU, batch normalisation, dropout and loss). Input: convs: Convolutional layers. hetero: Graph. hidden_size: Number of hidden dimensions of the convolutions. dropout: Dropout rate. """ def __init__(self, convs, hetero, hidden_size, dropout): super(HeteroGNN, self).__init__() self.convs1 = HeteroConv(convs[0]) # Wrap the heterogeneous GNN layers for the edge types in one. self.convs2 = HeteroConv(convs[1]) self.loss_fn = torch.nn.BCEWithLogitsLoss() self.bns1 = nn.ModuleDict() self.bns2 = nn.ModuleDict() self.dropout1 = nn.ModuleDict() self.relus1 = nn.ModuleDict() for node_type in hetero.node_types: self.bns1[node_type] = torch.nn.BatchNorm1d(hidden_size) self.bns2[node_type] = torch.nn.BatchNorm1d(hidden_size) self.relus1[node_type] = nn.LeakyReLU() self.dropout1[node_type] = nn.Dropout(p=dropout) """ It generates embeddings for a given graph applying all the instantiated layers in the previous function. Input: data: Graph. Output: Node embeddings. """ def getEmbeddings(self, data): x = data.node_feature edge_index = data.edge_index edge_weight = data.edge_feature # Edge weights need to be reformated to match model's requirements. keys = [key for key in edge_weight] for key in keys: newKey = key[1] edge_weight[newKey] = edge_weight[key] del edge_weight[key] x = self.convs1(x, edge_index, edge_weight) x = forward_op(x, self.bns1) x = forward_op(x, self.relus1) x = forward_op(x, self.dropout1) x = self.convs2(x, edge_index, edge_weight) x = forward_op(x, self.bns2) return x """ It generates predictions of edge probability in a graph. Input: data: Graph Output: Edge predictions. """ def forward(self, data): x = self.getEmbeddings(data) pred = {} for message_type in edges: nodes_first = None nodes_second = None if message_type == ('disorder', 'dis_dru_the', 'drug'): # Selects the embeddings of the needed node type, for the edges of the graph. nodes_first = torch.index_select(x['disorder'], 0, data.edge_label_index[message_type][0, :].long()) nodes_second = torch.index_select(x['drug'], 0, data.edge_label_index[message_type][1, :].long()) # Provided as example of other edge type, not used. elif message_type == ('disorder', 'sym_dse', 'disorder'): nodes_first = torch.index_select(x['disorder'], 0, data.edge_label_index[message_type][0, :].long()) nodes_second = torch.index_select(x['disorder'], 0, data.edge_label_index[message_type][1, :].long()) # Applies dot product and sigmoid to the lists containing the embeddings of the nodes. pred[message_type] = torch.sigmoid(torch.sum(nodes_first * nodes_second, dim=-1)) return pred """ It generates predictions of edge probability in a graph for all the possible edges. Input: data: Graph Output: Edge predictions. """ def predict_all(self, data): x = self.getEmbeddings(data) pred = {} for message_type in edges: nodes_first = None nodes_second = None if message_type == ('disorder', 'dis_dru_the', 'drug'): nodes_first = x['disorder'] nodes_second = x['drug'] # Provided as example of other edge type, not used. elif message_type == ('disorder', 'sym_dse', 'disorder'): nodes_first = x['disorder'] nodes_second = x['disorder'] for i, elem in enumerate(nodes_first): pred[message_type, i] = torch.sigmoid(torch.sum(elem * nodes_second, dim=-1)) return pred """ It generates predictions of edge probability in a graph for all the possible edges for a given type. Input: data: Graph type: Node type to generate all possible edges. id: Node id to target. Output: Edge predictions. """ def predict_all_type(self, data, type, id): x = self.getEmbeddings(data) pred = {} for message_type in edges: nodes_first = None nodes_second = None if message_type == ('disorder', 'dis_dru_the', 'drug') and type == 'disease': nodes_first = x['disorder'][id].unsqueeze(0) nodes_second = x['drug'] elif message_type == ('disorder', 'dis_dru_the', 'drug') and type == 'drug': nodes_first = x['disorder'] nodes_second = x['drug'][id].unsqueeze(0) for i, elem in enumerate(nodes_first): pred[message_type, i] = torch.sigmoid(torch.sum(elem * nodes_second, dim=-1)) return pred """ It generates predictions of edge probability in a graph for a given set of edges Input: data: Graph. eid: Edge id to generate predictions. Output: Edge predictions. """ def pred(self, data, eid): x = self.getEmbeddings(data) heads = x['disorder'] tails = x['drug'] pred = [] for head, tail in zip(eid[0], eid[1]): pred.append(torch.sigmoid(torch.sum(heads[head] * tails[tail], dim=-1))) return pred """ It calculates the loss for a given set of predictions. Input: pred: Predictions. y: Real labels. Output: Loss. """ def loss(self, pred, y): loss = 0 for key in pred: p = pred[key] loss += self.loss_fn(p, y[key].type(pred[key].dtype)) return loss """ It trains the model. Input: model: Model to train. dataloaders: Dataloaders containing the subgraphs for training, testing and validating. optimizer: Optimizer to train. epochs: Number of epochs to train the model. Output: List of accuracies of the model for all the dataloaders for every epoch. """ def train(model, dataloaders, optimizer, epochs): val_max = 0 best_model = model t_accu = [] v_accu = [] e_accu = [] lossL = [] lossLV = [] for epoch in range(1, epochs + 1): for iter_i, batch in enumerate(dataloaders['train']): batch.to(device) model.train() optimizer.zero_grad() pred = model(batch) loss = model.loss(pred, batch.edge_label) if epoch == epochs: loss.backward() else: loss.backward(retain_graph=True) optimizer.step() log = ' Epoch: {:03d}, Loss: {:.4f}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}' accs = test(model, dataloaders) t_accu.append(accs['train']) v_accu.append(accs['val']) e_accu.append(accs['test']) lossL.append(loss.cpu().detach().numpy()) lossLV.append(accs['valLoss'].cpu().detach().numpy()) print(log.format(epoch, loss.item(), accs['train'], accs['val'], accs['test'])) if val_max < accs['val']: val_max = accs['val'] best_model = copy.deepcopy(model) torch.cuda.empty_cache() log = 'Best: Train: {:.4f}, Val: {:.4f}, Test: {:.4f}' accs = test(best_model, dataloaders) print(log.format(accs['train'], accs['val'], accs['test'])) plt.plot(lossL) plt.plot(lossLV) plt.legend(['Train', 'Validation']) plt.ylabel('Loss') plt.xlabel('Iteration') plt.xticks(range(0, epochs + 1, int(epochs / 10))) plt.savefig('metrics/lossDMSR.svg', format='svg', dpi=1200) plt.clf() plt.close() return t_accu, v_accu, e_accu """ It tests the model, with a threshold of 0.5 to determine whether a prediction is considered as existance or not. Input: model: Model to train. dataloaders: Dataloaders containing the subgraphs for training, testing and validating. Output: List of accuracies of the model for all the dataloaders and their losses. """ def test(model, dataloaders): model.eval() accs = {} for mode, dataloader in dataloaders.items(): acc = 0 for i, batch in enumerate(dataloader): num = 0 batch.to(device) pred = model(batch) loss = model.loss(pred, batch.edge_label) for key in pred: p = pred[key].cpu().detach().numpy() pred_label = np.zeros_like(p, dtype=np.int64) pred_label[np.where(p > 0.5)[0]] = 1 pred_label[np.where(p <= 0.5)[0]] = 0 acc += np.sum(pred_label == batch.edge_label[key].cpu().numpy()) num += len(pred_label) accs[mode] = acc / num accs[mode + 'Loss'] = loss return accs """ It tests the model. Input: model: Model to train. test_loader: Dataloaders containing the subgraphs for testing. Output: Predictions without applying threshold. True labels of the edges. Keys of the tested edge types. """ def test2(model, test_loader): true_labels = [] keys = [] pure_pred_labels = [] for batch in test_loader: batch.to(device) pred = model(batch) for key in pred: p = pred[key].cpu().detach().numpy() pure_pred_labels.append(p) true_labels.append(batch.edge_label[key].cpu().detach().numpy()) keys.append(key) return pure_pred_labels, true_labels, keys """ It wraps all the functions for training and testing the model. Input: epochs: Number of epochs to train the model. hidden_dim: Number of hidden dimensions of the convolutions. lr: Learning rate for the model's training. weight_decay: Weight decay as regularisation. dropout: Dropout rate. Output: Trained model. """ def main(epochs, hidden_dim, lr, weight_decay, dropout): # Graph construction. constructor = heterograph_construction.DISNETConstructor(device='cuda') hetero, _ = constructor.DISNETHeterograph(full=False, withoutRepoDB=True) # Graph splitting into dataloaders. edge_train_mode = 'disjoint' print('edge train mode: {}'.format(edge_train_mode)) dataset = GraphDataset( [hetero], task='link_pred', edge_train_mode=edge_train_mode, edge_message_ratio=0.8, edge_negative_sampling_ratio=1 ) dataset_train, dataset_val, dataset_test = dataset.split( transductive=True, split_ratio=[0.8, 0.1, 0.1], shuffle=True ) train_loader = DataLoader( dataset_train, collate_fn=Batch.collate(), batch_size=1 ) val_loader = DataLoader( dataset_val, collate_fn=Batch.collate(), batch_size=1 ) test_loader = DataLoader( dataset_test, collate_fn=Batch.collate(), batch_size=1 ) dataloaders = { 'train': train_loader, 'val': val_loader, 'test': test_loader } # Model instantiation convs = generate_convs_link_pred_layers(hetero, HeteroSAGEConv, hidden_dim) model = HeteroGNN(convs, hetero, hidden_dim, dropout).to(device) # Optimizer instantiation optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay) # Training print("Started training at", datetime.now().strftime("%H:%M:%S")) _ = train(model, dataloaders, optimizer, epochs) print("Finished training at", datetime.now().strftime("%H:%M:%S")) torch.save(model.state_dict(), "./models/dmsr") # Testing model.eval() print("Started testing at", datetime.now().strftime("%H:%M:%S")) pure_pred_labels, true_labels, keys = test2(model, test_loader) print("Finished testing at", datetime.now().strftime("%H:%M:%S")) # Metrics calculations. labels = [item for sublist in true_labels for item in sublist] pure_predictions = [item for sublist in pure_pred_labels for item in sublist] plot_roc(labels, pure_predictions, keys[0], "dmsr/") plot_prc(torch.tensor(labels), pure_predictions, keys[0], "dmsr/") return model if __name__ == '__main__': # Set of random hyperparameters to train the model. main(400, 32, 0.01, 1e-4, 0)