"""Heterogeneous-GNN link prediction on a DISNET phenotype/drug graph.

Builds a two-layer HeteroSAGE model with DeepSnap, trains it on the
('phenotype', 'dis_dru_the', 'drug') edge type, tracks the best model by
validation loss, and reports ROC-AUC on train/val/test splits.
"""

import copy
import warnings
from datetime import datetime

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
from sklearn.metrics import roc_auc_score

from deepsnap.dataset import GraphDataset
from deepsnap.batch import Batch
from deepsnap.hetero_gnn import (
    HeteroSAGEConv,
    HeteroConv,
    forward_op
)

import heterograph_construction
from utilities import plot_roc, plot_prc

# DeepSnap has some deprecated functions and throws a warning.
warnings.filterwarnings("ignore")

# Message types on which link prediction is supervised and evaluated.
edges = [('phenotype', 'dis_dru_the', 'drug')]

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# ---------------------------
# FUNCTIONS
# ---------------------------
def generate_convs_link_pred_layers(hete, conv, hidden_size):
    """Create the per-message-type convolution dicts for a two-layer model.

    Args:
        hete: DeepSnap HeteroGraph providing `message_types` and
            `num_node_features`.
        conv: convolution class (e.g. HeteroSAGEConv) instantiated per
            message type.
        hidden_size: hidden embedding dimension of both layers.

    Returns:
        A list `[convs1, convs2]` of dicts mapping message type ->
        convolution module, for layers 1 and 2 respectively.
    """
    convs1 = {}
    convs2 = {}
    for message_type in hete.message_types:
        src_type = message_type[0]
        dst_type = message_type[2]
        # First layer maps raw node features to the hidden size; the
        # second layer stays hidden -> hidden.
        src_dim = hete.num_node_features(src_type)
        dst_dim = hete.num_node_features(dst_type)
        convs1[message_type] = conv(src_dim, hidden_size, dst_dim)
        convs2[message_type] = conv(hidden_size, hidden_size)
    return [convs1, convs2]


class HeteroGNN(torch.nn.Module):
    """Two-layer heterogeneous GNN producing per-node-type embeddings.

    Link scores are the sigmoid of the dot product between the head and
    tail node embeddings of a candidate edge.
    """

    def __init__(self, convs, hetero, hidden_size, dropout):
        super(HeteroGNN, self).__init__()
        self.convs1 = HeteroConv(convs[0])  # Wrap the heterogeneous GNN layers
        self.convs2 = HeteroConv(convs[1])
        self.bns1 = nn.ModuleDict()
        self.bns2 = nn.ModuleDict()
        self.dropout1 = nn.ModuleDict()
        self.relus1 = nn.ModuleDict()
        # FIX(review): forward() already applies torch.sigmoid to the edge
        # scores, so the loss must consume probabilities, not logits.
        # The original used BCEWithLogitsLoss here, which applies a second
        # sigmoid and attenuates gradients.
        self.loss_fn = torch.nn.BCELoss()
        for node_type in hetero.node_types:
            self.bns1[node_type] = nn.BatchNorm1d(hidden_size)
            self.bns2[node_type] = nn.BatchNorm1d(hidden_size)
            self.dropout1[node_type] = nn.Dropout(p=dropout)
            self.relus1[node_type] = nn.LeakyReLU()

    def getEmbeddings(self, data, training=True):
        """Run both conv layers and return {node_type: embeddings}.

        `training` is accepted for API symmetry but not otherwise used
        here (train/eval behavior is governed by the module mode).
        """
        x = data.node_feature
        edge_index = data.edge_index
        edge_weight = data.edge_feature
        # Re-key edge weights from the full (src, relation, dst) tuple to
        # the relation name only — presumably the format HeteroConv
        # expects; confirm against DeepSnap.
        # NOTE(review): this mutates the Batch in place; a second call on
        # the SAME batch object would re-index already-renamed string keys.
        keys = [key for key in edge_weight]
        for key in keys:
            newKey = key[1]
            edge_weight[newKey] = edge_weight[key]
            del edge_weight[key]
        x = self.convs1(x, edge_index, edge_weight)
        x = forward_op(x, self.bns1)
        x = forward_op(x, self.relus1)
        x = forward_op(x, self.dropout1)
        x = self.convs2(x, edge_index, edge_weight)
        x = forward_op(x, self.bns2)
        return x

    def forward(self, data):
        """Score the supervision edges of every message type in `edges`.

        Returns:
            dict mapping message type -> tensor of sigmoid link
            probabilities for the edges in data.edge_label_index.
        """
        x = self.getEmbeddings(data)
        pred = {}
        for message_type in edges:
            nodes_first = None
            nodes_second = None
            # Select head/tail embeddings according to the endpoint node
            # types of the message type being scored.
            if message_type == ('phenotype', 'dis_dru_the', 'drug'):
                nodes_first = torch.index_select(
                    x['phenotype'], 0,
                    data.edge_label_index[message_type][0, :].long())
                nodes_second = torch.index_select(
                    x['drug'], 0,
                    data.edge_label_index[message_type][1, :].long())
            elif message_type == ('drug', 'dru_dis_the', 'phenotype'):
                nodes_first = torch.index_select(
                    x['drug'], 0,
                    data.edge_label_index[message_type][0, :].long())
                nodes_second = torch.index_select(
                    x['phenotype'], 0,
                    data.edge_label_index[message_type][1, :].long())
            elif message_type == ('phenotype', 'dse_sym', 'phenotype'):
                nodes_first = torch.index_select(
                    x['phenotype'], 0,
                    data.edge_label_index[message_type][0, :].long())
                nodes_second = torch.index_select(
                    x['phenotype'], 0,
                    data.edge_label_index[message_type][1, :].long())
            elif message_type == ('phenotype', 'sym_dse', 'phenotype'):
                nodes_first = torch.index_select(
                    x['phenotype'], 0,
                    data.edge_label_index[message_type][0, :].long())
                nodes_second = torch.index_select(
                    x['phenotype'], 0,
                    data.edge_label_index[message_type][1, :].long())
            # Dot-product decoder squashed to a probability.
            pred[message_type] = torch.sigmoid(
                torch.sum(nodes_first * nodes_second, dim=-1))
        return pred

    def predict_all(self, data):
        """Score every (phenotype, drug) pair.

        Returns:
            dict keyed by (message_type, phenotype_index) -> tensor of
            probabilities against all drugs.
        """
        x = self.getEmbeddings(data)
        pred = {}
        for message_type in edges:
            nodes_first = None
            nodes_second = None
            if message_type == ('phenotype', 'dis_dru_the', 'drug'):
                nodes_first = x['phenotype']
                nodes_second = x['drug']
            for i, elem in enumerate(nodes_first):
                pred[message_type, i] = torch.sigmoid(
                    torch.sum(elem * nodes_second, dim=-1))
        return pred

    def predict_all_type(self, data, type, id):
        """Score one node (`id` of node type `type`) against all partners.

        NOTE(review): `type` and `id` shadow builtins but are kept because
        they are part of the public (keyword-callable) signature.
        """
        x = self.getEmbeddings(data)
        pred = {}
        for message_type in edges:
            nodes_first = None
            nodes_second = None
            if message_type == ('phenotype', 'dis_dru_the', 'drug') and type == 'phenotype':
                nodes_first = x['phenotype'][id].unsqueeze(0)
                nodes_second = x['drug']
            elif message_type == ('phenotype', 'dis_dru_the', 'drug') and type == 'drug':
                nodes_first = x['phenotype']
                nodes_second = x['drug'][id].unsqueeze(0)
            elif message_type == ('phenotype', 'dse_sym', 'phenotype'):
                nodes_first = x['phenotype']
                nodes_second = x['phenotype']
            for i, elem in enumerate(nodes_first):
                pred[message_type, i] = torch.sigmoid(
                    torch.sum(elem * nodes_second, dim=-1))
        return pred

    def pred(self, data, eid):
        """Score explicit (phenotype, drug) index pairs.

        Args:
            eid: pair of index sequences — eid[0] phenotype indices,
                eid[1] drug indices.

        Returns:
            list of scalar probability tensors, one per pair.
        """
        x = self.getEmbeddings(data, False)
        heads = x['phenotype']
        tails = x['drug']
        pred = []
        for head, tail in zip(eid[0], eid[1]):
            pred.append(torch.sigmoid(
                torch.sum(heads[head] * tails[tail], dim=-1)))
        return pred

    def loss(self, pred, y):
        """Sum the BCE loss over all predicted message types."""
        loss = 0
        for key in pred:
            # Cast labels to the prediction dtype so BCE accepts them.
            loss += self.loss_fn(pred[key], y[key].type(pred[key].dtype))
        return loss


def train(model, dataloaders, optimizer, epochs):
    """Train `model`, tracking the best checkpoint by validation loss.

    The best model is only considered during the last 200 epochs and when
    validation loss drops below 1. Saves a loss-evolution plot.

    Returns:
        (train_accs, val_accs, test_accs, best_model)
    """
    min_value = 1
    best_model = model
    # FIX(review): best_it was previously unbound if the best-model
    # condition never fired, crashing the final print/plot.
    best_it = 0
    t_accu = []
    v_accu = []
    e_accu = []
    lossL = []
    lossLV = []
    for epoch in range(1, epochs + 1):
        for batch in dataloaders['train']:
            batch.to(device)
            model.train()
            optimizer.zero_grad()
            pred = model(batch)
            loss = model.loss(pred, batch.edge_label)
            # NOTE(review): retain_graph is likely unnecessary since the
            # graph is rebuilt each forward pass; kept to preserve the
            # original behavior.
            if epoch == epochs:
                loss.backward()
            else:
                loss.backward(retain_graph=True)
            optimizer.step()
        accs = test(model, dataloaders)
        #log = ' Epoch: {:03d}, Loss: {:.4f}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}'
        #print(log.format(epoch, loss.item(), accs['train'], accs['val'], accs['test']))
        t_accu.append(accs['train'])
        v_accu.append(accs['val'])
        e_accu.append(accs['test'])
        lossL.append(loss.cpu().detach().numpy())
        lossLV.append(accs['valLoss'].cpu().detach().numpy())
        # Only checkpoint near the end of training to skip early noise.
        if min_value > lossLV[-1] and epoch >= epochs - 200:
            min_value = lossLV[-1]
            best_model = copy.deepcopy(model)
            best_it = epoch
        torch.cuda.empty_cache()
    log = 'Best: Iteration {:d} Train: {:.4f}, Val: {:.4f}, Test: {:.4f}'
    accs = test(best_model, dataloaders)
    print(log.format(best_it, accs['train'], accs['val'], accs['test']))
    plt.plot(lossL)
    plt.plot(lossLV)
    plt.plot(best_it, min_value, 'k+', linewidth=100)
    plt.title('Loss Evolution')
    plt.legend(['Train', 'Validation', 'Best Model'])
    plt.ylabel('Loss')
    plt.xlabel('Iteration')
    plt.xticks(range(0, epochs + 1, int(epochs / 10)))
    plt.yticks((0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85))
    plt.savefig('plots/behor/metrics/loss.svg', format='svg', dpi=1200)
    plt.clf()
    return t_accu, v_accu, e_accu, best_model


def test(model, dataloaders):
    """Compute mean ROC-AUC (and last-batch loss) per dataloader split.

    Returns:
        dict with keys '<mode>' (mean ROC-AUC) and '<mode>Loss'
        (loss of the last batch of that split).
    """
    model.eval()
    accs = {}
    for mode, dataloader in dataloaders.items():
        acc = 0
        num = 0
        for batch in dataloader:
            batch.to(device)
            preds = model(batch)
            loss = model.loss(preds, batch.edge_label)
            # FIX(review): the original rebound the dict variable inside
            # this loop (`pred = pred[key]...`), which would crash with
            # more than one evaluated message type.
            for key in preds:
                scores = preds[key].flatten().cpu().detach().numpy()
                label = batch.edge_label[key].flatten().cpu().numpy()
                acc += roc_auc_score(label, scores)
                num += 1
        accs[mode] = acc / num
        accs[mode + 'Loss'] = loss
    return accs


def test2(model, test_loader):
    """Collect raw predictions, true labels and their message-type keys.

    Returns:
        (pure_pred_labels, true_labels, keys) — parallel lists with one
        numpy array / key per (batch, message type).
    """
    true_labels = []
    keys = []
    pure_pred_labels = []
    for batch in test_loader:
        batch.to(device)
        pred = model(batch)
        for key in pred:
            p = pred[key].cpu().detach().numpy()
            pure_pred_labels.append(p)
            true_labels.append(batch.edge_label[key].cpu().detach().numpy())
            keys.append(key)
    return pure_pred_labels, true_labels, keys


def main(epochs, hidden_dim, lr, weight_decay, dropout):
    """Build the graph, train, save the best model and plot ROC/PRC.

    Args:
        epochs: number of training epochs.
        hidden_dim: hidden embedding size.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        dropout: dropout probability after the first conv layer.

    Returns:
        The best trained model.
    """
    # NOTE(review): hard-coded 'cuda' here disagrees with the global
    # `device` fallback to CPU — confirm whether CPU runs are intended.
    constructor = heterograph_construction.DISNETConstructor(device='cuda')
    hetero, _ = constructor.DISNETHeterographDeepSnap(full=False, withoutRepoDB=True)
    edge_train_mode = 'disjoint'
    print('edge train mode: {}'.format(edge_train_mode))
    dataset = GraphDataset(
        [hetero],
        task='link_pred',
        edge_train_mode=edge_train_mode,
        edge_message_ratio=0.8,
        edge_negative_sampling_ratio=1
    )
    dataset_train, dataset_val, dataset_test = dataset.split(
        transductive=True, split_ratio=[0.8, 0.1, 0.1], shuffle=True
    )
    train_loader = DataLoader(
        dataset_train, collate_fn=Batch.collate(), batch_size=1
    )
    val_loader = DataLoader(
        dataset_val, collate_fn=Batch.collate(), batch_size=1
    )
    test_loader = DataLoader(
        dataset_test, collate_fn=Batch.collate(), batch_size=1
    )
    dataloaders = {
        'train': train_loader,
        'val': val_loader,
        'test': test_loader
    }
    convs = generate_convs_link_pred_layers(hetero, HeteroSAGEConv, hidden_dim)
    model = HeteroGNN(convs, hetero, hidden_dim, dropout).to(device)
    optimizer = torch.optim.Adam(
        model.parameters(), lr=lr, weight_decay=weight_decay)
    print("Started training at", datetime.now().strftime("%H:%M:%S"))
    _, _, _, model = train(model, dataloaders, optimizer, epochs)
    print("Finished training at", datetime.now().strftime("%H:%M:%S"))
    torch.save(model.state_dict(), "models/behor")
    # Testing
    model.eval()
    print("Started testing at", datetime.now().strftime("%H:%M:%S"))
    pure_pred_labels, true_labels, keys = test2(model, test_loader)
    print("Finished testing at", datetime.now().strftime("%H:%M:%S"))
    labels = [item for sublist in true_labels for item in sublist]
    pure_predictions = [item for sublist in pure_pred_labels for item in sublist]
    plot_roc(labels, pure_predictions, keys[0], "behor/")
    plot_prc(torch.tensor(labels), pure_predictions, keys[0], "behor/")
    return model


if __name__ == '__main__':
    main(2752, 107, 0.0008317, 0.006314, 0.8)