import argparse
import copy
from sklearn.metrics import roc_auc_score
from utilities import plot_roc, plot_prc
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
import heterograph_construction
from datetime import datetime
from deepsnap.dataset import GraphDataset
from deepsnap.batch import Batch
from deepsnap.hetero_gnn import (
    HeteroSAGEConv,
    HeteroConv,
    forward_op
)

# DeepSNAP has some deprecated functions and throws a warning.
import warnings
warnings.filterwarnings("ignore")

# Message types whose edges are scored by the link predictor.
edges = [('disorder', 'dis_dru_the', 'drug')]

# ---------------------------
# FUNCTIONS
# ---------------------------


def arg_parse():
    parser = argparse.ArgumentParser(description='Link prediction arguments.')
    parser.add_argument('--device', type=str, help='CPU / GPU device.')
    parser.add_argument('--epochs', type=int, help='Number of epochs to train.')
    parser.add_argument('--mode', type=str, help='Link prediction mode: disjoint or all.')
    parser.add_argument('--edge_message_ratio', type=float,
                        help='Ratio of edges used for message passing (only in disjoint mode).')
    # hidden_dim is a single integer (it was declared as type=list, which breaks CLI parsing).
    parser.add_argument('--hidden_dim', type=int, help='Hidden dimension of the GNN.')
    parser.add_argument('--lr', type=float, help='The learning rate.')
    parser.add_argument('--weight_decay', type=float, help='Weight decay.')
    parser.set_defaults(
        epochs=10000,
        device='cuda',
        mode="disjoint",
        edge_message_ratio=0.8,
        hidden_dim=45,
        lr=0.001,
        weight_decay=0.001,
    )
    return parser.parse_args()


def generate_convs_link_pred_layers(hete, conv, hidden_size):
    # Build one convolution per message type for each of the two GNN layers.
    convs1 = {}
    convs2 = {}
    for message_type in hete.message_types:
        n_type = message_type[0]
        s_type = message_type[2]
        # Attention: this should be changed to n_feat_dim = hete.num_node_features(n_type) and
        # s_feat_dim = hete.num_node_features(s_type) if the features are not an identity matrix.
        n_feat_dim = hete.num_node_features(n_type)
        s_feat_dim = hete.num_node_features(s_type)
        convs1[message_type] = conv(n_feat_dim, hidden_size, s_feat_dim)
        convs2[message_type] = conv(hidden_size, hidden_size)
    return [convs1, convs2]


class HeteroGNN(torch.nn.Module):
    """Two-layer heterogeneous GNN encoder with a dot-product link decoder."""

    def __init__(self, convs, hetero, hidden_size):
        super(HeteroGNN, self).__init__()
        self.convs1 = HeteroConv(convs[0])  # Wrap the heterogeneous GNN layers
        self.convs2 = HeteroConv(convs[1])
        self.bns1 = nn.ModuleDict()
        self.bns2 = nn.ModuleDict()
        self.relus1 = nn.ModuleDict()
        self.loss_fn = torch.nn.BCEWithLogitsLoss()
        for node_type in hetero.node_types:
            self.bns1[node_type] = nn.BatchNorm1d(hidden_size)
            self.bns2[node_type] = nn.BatchNorm1d(hidden_size)
            self.relus1[node_type] = nn.LeakyReLU()

    def getEmbeddings(self, data, training=True):
        x = data.node_feature
        edge_index = data.edge_index
        # Commented-out variant that also passed edge weights to the convolutions:
        # edge_weight = data.edge_feature
        # keys = [key for key in edge_weight]
        # for key in keys:
        #     newKey = key[1]
        #     edge_weight[newKey] = edge_weight[key]
        #     del edge_weight[key]
        # x = self.convs1(x, edge_index, edge_weight)
        # x = forward_op(x, self.bns1)
        # x = forward_op(x, self.relus1)
        # x = self.convs2(x, edge_index, edge_weight)
        # x = forward_op(x, self.bns2)
        x = self.convs1(x, edge_index)
        x = forward_op(x, self.bns1)
        x = forward_op(x, self.relus1)
        x = self.convs2(x, edge_index)
        x = forward_op(x, self.bns2)
        return x

    def forward(self, data):
        x = self.getEmbeddings(data)
        pred = {}
        for message_type in edges:
            nodes_first = None
            nodes_second = None
            if message_type == ('disorder', 'dis_dru_the', 'drug'):
                nodes_first = torch.index_select(x['disorder'], 0, data.edge_label_index[message_type][0, :].long())
                nodes_second = torch.index_select(x['drug'], 0, data.edge_label_index[message_type][1, :].long())
            elif message_type == ('drug', 'dru_dis_the', 'disorder'):
                nodes_first = torch.index_select(x['drug'], 0, data.edge_label_index[message_type][0, :].long())
                nodes_second = torch.index_select(x['disorder'], 0, data.edge_label_index[message_type][1, :].long())
            elif message_type == ('disorder', 'dse_sym', 'disorder'):
                nodes_first = torch.index_select(x['disorder'], 0, data.edge_label_index[message_type][0, :].long())
                nodes_second = torch.index_select(x['disorder'], 0, data.edge_label_index[message_type][1, :].long())
            elif message_type == ('disorder', 'sym_dse', 'disorder'):
                nodes_first = torch.index_select(x['disorder'], 0, data.edge_label_index[message_type][0, :].long())
                nodes_second = torch.index_select(x['disorder'], 0, data.edge_label_index[message_type][1, :].long())
            # Dot-product decoder. Note that the loss is BCEWithLogitsLoss, which applies another
            # sigmoid internally, so these scores are treated as logits by the loss.
            pred[message_type] = torch.sigmoid(torch.sum(nodes_first * nodes_second, dim=-1))
        return pred

    def predict_all(self, data):
        # Score every candidate tail node for each head node of the message types in `edges`.
        x = self.getEmbeddings(data)
        pred = {}
        for message_type in edges:
            nodes_first = None
            nodes_second = None
            if message_type == ('disorder', 'dis_dru_the', 'drug'):
                nodes_first = x['disorder']
                nodes_second = x['drug']
            elif message_type == ('disorder', 'dse_sym', 'disorder'):
                nodes_first = x['disorder']
                nodes_second = x['disorder']
            for i, elem in enumerate(nodes_first):
                pred[message_type, i] = torch.sigmoid(torch.sum(elem * nodes_second, dim=-1))
        return pred

    def predict_all_type(self, data, type, id):
        # Score all candidate partners of a single node (`id`) of the given `type`.
        x = self.getEmbeddings(data)
        pred = {}
        for message_type in edges:
            nodes_first = None
            nodes_second = None
            if message_type == ('disorder', 'dis_dru_the', 'drug') and type == 'disease':
                nodes_first = x['disorder'][id].unsqueeze(0)
                nodes_second = x['drug']
            elif message_type == ('disorder', 'dis_dru_the', 'drug') and type == 'drug':
                nodes_first = x['disorder']
                nodes_second = x['drug'][id].unsqueeze(0)
            elif message_type == ('disorder', 'dse_sym', 'disorder'):
                nodes_first = x['disorder']
                nodes_second = x['disorder']
            for i, elem in enumerate(nodes_first):
                pred[message_type, i] = torch.sigmoid(torch.sum(elem * nodes_second, dim=-1))
        return pred

    def pred(self, data, eid):
        # Score the specific (disorder, drug) index pairs given in eid.
        x = self.getEmbeddings(data, False)
        heads = x['disorder']
        tails = x['drug']
        pred = []
        for head, tail in zip(eid[0], eid[1]):
            pred.append(torch.sigmoid(torch.sum(heads[head] * tails[tail], dim=-1)))
        return pred

    def loss(self, pred, y):
        loss = 0
        for key in pred:
            loss += self.loss_fn(pred[key], y[key].type(pred[key].dtype))
        return loss


def train(model, dataloaders, optimizer, scheduler, args):
    min_value = 1
    best_model = model
    best_it = 0  # Initialised so the final report works even if no epoch improves the validation loss.
    t_accu = []
    v_accu = []
    e_accu = []
    lossL = []
    lossLV = []
    for epoch in range(1, args.epochs + 1):
        for batch in dataloaders['train']:
            batch.to(args.device)
            model.train()
            optimizer.zero_grad()
            pred = model(batch)
            loss = model.loss(pred, batch.edge_label)
            if epoch == args.epochs:
                loss.backward()
            else:
                loss.backward(retain_graph=True)
            optimizer.step()

        log = 'Epoch: {:03d}, Loss: {:.4f}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}'
        accs = test(model, dataloaders, args)
        t_accu.append(accs['train'])
        v_accu.append(accs['val'])
        e_accu.append(accs['test'])
        lossL.append(loss.cpu().detach().numpy())
        lossLV.append(accs['valLoss'].cpu().detach().numpy())
        # scheduler.step(loss)
        print(log.format(epoch, loss.item(), accs['train'], accs['val'], accs['test']))

        # Keep the model with the lowest validation loss among the last 200 epochs.
        if min_value > lossLV[-1] and epoch >= args.epochs - 200:
            min_value = lossLV[-1]
            best_model = copy.deepcopy(model)
            best_it = epoch
        torch.cuda.empty_cache()

    log = 'Best: Iteration {:d} Train: {:.4f}, Val: {:.4f}, Test: {:.4f}'
    accs = test(best_model, dataloaders, args)
    print(log.format(best_it, accs['train'], accs['val'], accs['test']))

    plt.plot(lossL)
    plt.plot(lossLV)
    plt.plot(best_it, min_value, 'k+', linewidth=100)
    plt.title('Loss Evolution')
    plt.legend(['Train', 'Validation', 'Best Model'])
    plt.ylabel('Loss')
    plt.xlabel('Iteration')
    plt.xticks(range(0, args.epochs + 1, int(args.epochs / 10)))
    plt.yticks((0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85))
    plt.savefig('plots/deepSnapPred/metrics/loss.svg', format='svg', dpi=1200)
    plt.clf()
    return t_accu, v_accu, e_accu, best_model


def test(model, dataloaders, args):
    # Compute the ROC AUC (and last loss) of each split.
    model.eval()
    accs = {}
    for mode, dataloader in dataloaders.items():
        acc = 0
        num = 0
        for batch in dataloader:
            batch.to(args.device)
            pred = model(batch)
            loss = model.loss(pred, batch.edge_label)
            for key in pred:
                p = pred[key].flatten().cpu().detach().numpy()
                label = batch.edge_label[key].flatten().cpu().numpy()
                acc += roc_auc_score(label, p)
                num += 1
        accs[mode] = acc / num
        accs[mode + 'Loss'] = loss
    return accs


def test2(model, test_loader, args):
    # Collect raw predictions and labels on the test split for ROC / PRC plotting.
    true_labels = []
    keys = []
    pure_pred_labels = []
    for batch in test_loader:
        batch.to(args.device)
        pred = model(batch)
        for key in pred:
            p = pred[key].cpu().detach().numpy()
            pure_pred_labels.append(p)
            true_labels.append(batch.edge_label[key].cpu().detach().numpy())
            keys.append(key)
    return pure_pred_labels, true_labels, keys


def main():
    args = arg_parse()

    # constructor = heterograph_construction.DISNETConstructor(device='cuda')
    constructor = heterograph_construction.DISNETConstructor()
    hetero, _ = constructor.DISNETHeterographDeepSnap(full=True, withoutRepoDB=True)

    print(hetero.num_nodes())
    print(hetero.num_node_labels())
    print(hetero.num_edges())

    data = hetero.num_edges()
    total = 0
    for key in data:
        total += data[key]
    print("Total: ", total)
print("Contribution of ", key, " is ", data[key]/total) edge_train_mode = args.mode print('edge train mode: {}'.format(edge_train_mode)) dataset = GraphDataset( [hetero], task='link_pred', edge_train_mode=edge_train_mode, edge_message_ratio=args.edge_message_ratio, edge_negative_sampling_ratio=1 ) dataset_train, dataset_val, dataset_test = dataset.split( transductive=True, split_ratio=[0.8, 0.1, 0.1], shuffle=True ) train_loader = DataLoader( dataset_train, collate_fn=Batch.collate(), batch_size=1 ) val_loader = DataLoader( dataset_val, collate_fn=Batch.collate(), batch_size=1 ) test_loader = DataLoader( dataset_test, collate_fn=Batch.collate(), batch_size=1 ) dataloaders = { 'train': train_loader, 'val': val_loader, 'test': test_loader } hidden_size = args.hidden_dim convs = generate_convs_link_pred_layers(hetero, HeteroSAGEConv, hidden_size) model = HeteroGNN(convs, hetero, hidden_size).to(args.device) optimizer = torch.optim.Adam( model.parameters(), lr=args.lr, weight_decay=args.weight_decay) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience = 250) #500 print("Started training at", datetime.now().strftime("%H:%M:%S")) _, _, _, model = train(model, dataloaders, optimizer, scheduler, args) print("Finished training at", datetime.now().strftime("%H:%M:%S")) torch.save(model.state_dict(),"./models/modelDeepSnapPred") # Testing model.eval() print("Started testing at", datetime.now().strftime("%H:%M:%S")) pure_pred_labels, true_labels, keys = test2(model, test_loader, args) print("Finished testing at", datetime.now().strftime("%H:%M:%S")) labels = [item for sublist in true_labels for item in sublist] pure_predictions = [item for sublist in pure_pred_labels for item in sublist] plot_roc(labels, pure_predictions, keys[0], "deepSnapPred/") plot_prc(torch.tensor(labels), pure_predictions, keys[0], "deepSnapPred/") if __name__ == '__main__': main()