deepSnapPred.py 11.5 KB
Newer Older
ADRIAN  AYUSO MUNOZ's avatar
ADRIAN AYUSO MUNOZ committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
import copy
from sklearn.metrics import roc_auc_score
from utilities import plot_roc, plot_prc
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
import heterograph_construction
from datetime import datetime
from deepsnap.dataset import GraphDataset
from deepsnap.batch import Batch
from deepsnap.hetero_gnn import (
    HeteroSAGEConv,
    HeteroConv,
    forward_op
)

# DeepSnap calls some deprecated functions, which emit warnings.
# NOTE: the filter below silences *all* warnings, not only DeepSnap's.
import warnings
warnings.filterwarnings("ignore")

ADRIAN  AYUSO MUNOZ's avatar
ADRIAN AYUSO MUNOZ committed
22
# Message types (src, relation, dst) that the model scores and is trained on.
edges = [('phenotype', 'dis_dru_the', 'drug')]
# Prefer the GPU whenever one is available, otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
ADRIAN  AYUSO MUNOZ's avatar
ADRIAN AYUSO MUNOZ committed
24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44

# ---------------------------
# FUNCTIONS
# ---------------------------
def generate_convs_link_pred_layers(hete, conv, hidden_size):
    """Build the per-message-type convolutions for the two GNN layers.

    For every message type ``(src, relation, dst)`` in ``hete.message_types``:

    * layer 1 is ``conv(<src feature width>, hidden_size, <dst feature width>)``,
      i.e. it maps the raw node features to ``hidden_size``;
    * layer 2 is ``conv(hidden_size, hidden_size)``.

    Feature widths come from ``hete.num_node_features``, so non-identity node
    features are already supported.

    Args:
        hete: heterogeneous graph exposing ``message_types`` and
            ``num_node_features(node_type)``.
        conv: convolution class or factory (e.g. ``HeteroSAGEConv``).
        hidden_size: width of the hidden embeddings.

    Returns:
        ``[layer1_convs, layer2_convs]``: two dicts keyed by message type, ready
        to be wrapped in ``HeteroConv``.
    """
    layer_one = {}
    layer_two = {}
    for mtype in hete.message_types:
        # Both layers are created per message type, in this order, so the
        # random initialisation sequence stays the same.
        layer_one[mtype] = conv(
            hete.num_node_features(mtype[0]),  # neighbour (source) feature width
            hidden_size,
            hete.num_node_features(mtype[2]),  # self (destination) feature width
        )
        layer_two[mtype] = conv(hidden_size, hidden_size)
    return [layer_one, layer_two]

class HeteroGNN(torch.nn.Module):
    """Two-layer heterogeneous GNN with a dot-product link decoder.

    Encoder (see ``getEmbeddings``):
        HeteroConv -> BatchNorm -> LeakyReLU -> Dropout -> HeteroConv -> BatchNorm,
        with one BatchNorm / activation / dropout module per node type.

    Decoder: a candidate edge ``(u, v)`` is scored as ``sigmoid(<h_u, h_v>)``.
    Every prediction method therefore returns probabilities, and ``loss`` takes
    probabilities, not logits.
    """

    def __init__(self, convs, hetero, hidden_size, dropout):
        """
        Args:
            convs: ``[layer1_convs, layer2_convs]`` dicts keyed by message type,
                as produced by ``generate_convs_link_pred_layers``.
            hetero: graph whose ``node_types`` each get their own BatchNorm,
                activation and dropout modules.
            hidden_size: embedding width produced by both conv layers.
            dropout: dropout probability applied after the first layer.
        """
        super(HeteroGNN, self).__init__()

        self.convs1 = HeteroConv(convs[0])  # Wrap the heterogeneous GNN layers
        self.convs2 = HeteroConv(convs[1])

        self.bns1 = nn.ModuleDict()
        self.bns2 = nn.ModuleDict()
        self.dropout1 = nn.ModuleDict()
        self.relus1 = nn.ModuleDict()

        # The prediction methods already apply a sigmoid, so the loss must take
        # probabilities. BCEWithLogitsLoss would apply a second sigmoid on top,
        # squashing both the loss value and its gradients.
        self.loss_fn = torch.nn.BCELoss()

        for node_type in hetero.node_types:
            self.bns1[node_type] = nn.BatchNorm1d(hidden_size)
            self.bns2[node_type] = nn.BatchNorm1d(hidden_size)
            self.dropout1[node_type] = nn.Dropout(p=dropout)
            self.relus1[node_type] = nn.LeakyReLU()

    def getEmbeddings(self, data, training = True):
        """Run both GNN layers and return a dict of embeddings per node type.

        Args:
            data: DeepSnap heterogeneous graph/batch with ``node_feature``,
                ``edge_index`` and ``edge_feature`` dicts.
            training: unused. Dropout is governed by ``self.train()`` /
                ``self.eval()``. Kept for backward compatibility.

        Side effects:
            ``data.edge_feature`` is re-keyed in place from ``(src, rel, dst)``
            tuples to the bare relation name ``rel``.
        """
        x = data.node_feature
        edge_index = data.edge_index
        edge_weight = data.edge_feature

        if edge_weight is not None:
            # Re-key edge features by relation name before handing them to the
            # conv layers (presumably the key format HeteroConv expects —
            # confirm). Only tuple keys are renamed. A second call on the same
            # graph would otherwise turn 'dis_dru_the' into 'i' (its key[1]).
            for key in [k for k in edge_weight if isinstance(k, tuple)]:
                edge_weight[key[1]] = edge_weight[key]
                del edge_weight[key]

        x = self.convs1(x, edge_index, edge_weight)
        x = forward_op(x, self.bns1)
        x = forward_op(x, self.relus1)
        x = forward_op(x, self.dropout1)

        x = self.convs2(x, edge_index, edge_weight)
        x = forward_op(x, self.bns2)
        return x

    def forward(self, data):
        """Score the supervision edges of ``data``.

        Returns:
            dict mapping each message type in the module-level ``edges`` list to
            a 1-D tensor ``sigmoid(<h_src, h_dst>)``. The tensor has one entry
            per column of ``data.edge_label_index[message_type]``.
        """
        x = self.getEmbeddings(data)

        pred = {}
        for message_type in edges:
            # The endpoint node types are the src/dst of the (src, rel, dst) triple.
            src_type, dst_type = message_type[0], message_type[2]
            label_index = data.edge_label_index[message_type]
            nodes_first = torch.index_select(x[src_type], 0, label_index[0, :].long())
            nodes_second = torch.index_select(x[dst_type], 0, label_index[1, :].long())
            pred[message_type] = torch.sigmoid(torch.sum(nodes_first * nodes_second, dim=-1))
        return pred

    def predict_all(self, data):
        """Score every (source node, destination node) pair for each type in ``edges``.

        Returns:
            dict keyed by ``(message_type, i)``. Each value is a 1-D tensor with
            the probabilities between source node ``i`` and every destination node.
        """
        x = self.getEmbeddings(data)

        pred = {}
        for message_type in edges:
            nodes_first = x[message_type[0]]
            nodes_second = x[message_type[2]]
            for i, elem in enumerate(nodes_first):
                # elem broadcasts against all destination embeddings.
                pred[message_type, i] = torch.sigmoid(torch.sum(elem * nodes_second, dim=-1))
        return pred

    def predict_all_type(self, data, type, id):
        """Score a single query node against the nodes on the other side of each relation.

        Args:
            data: graph passed to ``getEmbeddings``.
            type: node type of the query node (``'phenotype'`` or ``'drug'``).
                Shadows the builtin; the name is kept for keyword callers.
            id: index of the query node within its node type. Shadows the
                builtin; the name is kept for keyword callers.

        Returns:
            dict keyed by ``(message_type, i)``, where ``i`` enumerates the rows
            of the source-side embeddings that were used. Message types with no
            matching branch are skipped.
        """
        x = self.getEmbeddings(data)

        pred = {}
        for message_type in edges:
            nodes_first = None
            nodes_second = None
            if message_type == ('phenotype', 'dis_dru_the', 'drug') and type == 'phenotype':
                nodes_first = x['phenotype'][id].unsqueeze(0)
                nodes_second = x['drug']
            elif message_type == ('phenotype', 'dis_dru_the', 'drug') and type == 'drug':
                nodes_first = x['phenotype']
                nodes_second = x['drug'][id].unsqueeze(0)
            elif message_type == ('phenotype', 'dse_sym', 'phenotype'):
                nodes_first = x['phenotype']
                nodes_second = x['phenotype']

            if nodes_first is None:
                # Unsupported (message_type, type) combination: skip it instead
                # of failing while iterating over None.
                continue

            for i, elem in enumerate(nodes_first):
                pred[message_type, i] = torch.sigmoid(torch.sum(elem * nodes_second, dim=-1))
        return pred

    def pred(self, data, eid):
        """Score explicit phenotype -> drug pairs.

        Args:
            data: graph passed to ``getEmbeddings``.
            eid: pair ``(phenotype_indices, drug_indices)`` of equal-length
                index sequences.

        Returns:
            list with one scalar probability tensor per pair.
        """
        x = self.getEmbeddings(data, False)
        heads = x['phenotype']
        tails = x['drug']
        return [
            torch.sigmoid(torch.sum(heads[head] * tails[tail], dim=-1))
            for head, tail in zip(eid[0], eid[1])
        ]

    def loss(self, pred, y):
        """Sum the binary cross-entropy over all message types in ``pred``.

        Args:
            pred: dict of probability tensors, as returned by ``forward``.
            y: dict of 0/1 labels with the same keys; cast to ``pred``'s dtype.
        """
        loss = 0
        for key in pred:
            loss += self.loss_fn(pred[key], y[key].type(pred[key].dtype))
        return loss


ADRIAN  AYUSO MUNOZ's avatar
ADRIAN AYUSO MUNOZ committed
172
def train(model, dataloaders, optimizer, epochs):
    """Train ``model`` and keep the copy with the lowest validation loss.

    After every optimisation step the model is evaluated on all loaders with
    ``test``. Only steps from the last 200 epochs are eligible to become the
    best model. A train/validation loss plot is written to
    ``plots/behor/metrics/loss.svg``.

    Args:
        model: module exposing ``loss(pred, labels)`` (e.g. ``HeteroGNN``).
        dataloaders: dict with ``'train'``, ``'val'`` and ``'test'`` loaders.
        optimizer: torch optimizer over ``model.parameters()``.
        epochs: number of passes over the training loader.

    Returns:
        ``(train_aucs, val_aucs, test_aucs, best_model)``. The AUC lists hold one
        entry per training step.
    """
    from pathlib import Path

    min_value = float('inf')  # any finite validation loss can become the best
    best_model = model
    best_it = 0  # stays 0 only if no step was eligible (e.g. epochs == 0)
    t_accu = []
    v_accu = []
    e_accu = []
    lossL = []
    lossLV = []

    for epoch in range(1, epochs + 1):
        for batch in dataloaders['train']:
            batch.to(device)
            model.train()  # test() below switches the model to eval mode

            optimizer.zero_grad()
            pred = model(batch)
            loss = model.loss(pred, batch.edge_label)

            # The graph is retained on every epoch except the last one.
            # NOTE(review): presumably required because tensors reused across
            # steps are part of an autograd graph — confirm before removing.
            if epoch == epochs:
                loss.backward()
            else:
                loss.backward(retain_graph=True)
            optimizer.step()

            accs = test(model, dataloaders)

            t_accu.append(accs['train'])
            v_accu.append(accs['val'])
            e_accu.append(accs['test'])

            lossL.append(loss.cpu().detach().numpy())
            lossLV.append(accs['valLoss'].cpu().detach().numpy())

            if min_value > lossLV[-1] and epoch >= epochs - 200:
                min_value = lossLV[-1]
                best_model = copy.deepcopy(model)
                best_it = epoch

            torch.cuda.empty_cache()

    log = 'Best: Iteration {:d} Train: {:.4f}, Val: {:.4f}, Test: {:.4f}'
    accs = test(best_model, dataloaders)
    print(log.format(best_it, accs['train'], accs['val'], accs['test']))

    # 1-based x values so the curves line up with the 1-based best_it marker.
    steps = range(1, len(lossL) + 1)
    plt.plot(steps, lossL)
    plt.plot(steps, lossLV)
    plt.plot(best_it, min_value, 'k+', linewidth=100)
    plt.title('Loss Evolution')
    plt.legend(['Train', 'Validation', 'Best Model'])
    plt.ylabel('Loss')
    plt.xlabel('Iteration')
    # int(epochs / 10) is 0 for epochs < 10, which range() rejects.
    plt.xticks(range(0, epochs + 1, max(1, epochs // 10)))
    out_path = Path('plots/behor/metrics/loss.svg')
    out_path.parent.mkdir(parents=True, exist_ok=True)  # savefig does not create dirs
    plt.savefig(out_path, format='svg', dpi=1200)
    plt.clf()

    return t_accu, v_accu, e_accu, best_model


ADRIAN  AYUSO MUNOZ's avatar
ADRIAN AYUSO MUNOZ committed
236
def test(model, dataloaders):
    """Evaluate ``model`` on every loader in ``dataloaders``.

    The model is put in eval mode and no autograd graph is built.

    Returns:
        dict containing, for each loader name ``mode``:

        * ``accs[mode]``: ROC-AUC averaged over all (batch, message type) pairs.
        * ``accs[mode + 'Loss']``: ``model.loss`` averaged over the batches of
          that loader, as a tensor.
    """
    model.eval()
    accs = {}
    with torch.no_grad():  # evaluation only; do not keep autograd graphs alive
        for mode, dataloader in dataloaders.items():
            acc = 0
            num = 0
            loss_sum = 0
            num_batches = 0
            for batch in dataloader:
                batch.to(device)
                pred = model(batch)
                loss_sum += model.loss(pred, batch.edge_label)
                num_batches += 1
                for key in pred:
                    # Use separate names: rebinding ``pred`` here would break the
                    # loop as soon as there is a second message type.
                    scores = pred[key].flatten().cpu().numpy()
                    label = batch.edge_label[key].flatten().cpu().numpy()

                    acc += roc_auc_score(label, scores)
                    num += 1

            accs[mode] = acc / num
            accs[mode + 'Loss'] = loss_sum / num_batches
    return accs


ADRIAN  AYUSO MUNOZ's avatar
ADRIAN AYUSO MUNOZ committed
258
def test2(model, test_loader):
    """Collect the raw model outputs and labels from ``test_loader``.

    Runs without building an autograd graph. The caller is responsible for
    putting ``model`` in eval mode.

    Returns:
        ``(predictions, labels, keys)``: three parallel lists with one entry per
        (batch, message type). Each entry holds the NumPy array of model
        outputs, the matching NumPy array of edge labels, and the message type.
    """
    true_labels = []
    keys = []
    pure_pred_labels = []
    with torch.no_grad():  # inference only
        for batch in test_loader:
            batch.to(device)
            pred = model(batch)
            for key in pred:
                pure_pred_labels.append(pred[key].cpu().detach().numpy())
                true_labels.append(batch.edge_label[key].cpu().detach().numpy())
                keys.append(key)
    return pure_pred_labels, true_labels, keys


ADRIAN  AYUSO MUNOZ's avatar
ADRIAN AYUSO MUNOZ committed
273 274 275
def main(epochs, hidden_dim, lr, weight_decay, dropout):
    """Train and evaluate a HeteroGNN link predictor on the DISNET graph.

    Steps:

    1. Build the DeepSnap heterograph.
    2. Split its edges transductively 80/10/10 into train/val/test.
    3. Train with Adam.
    4. Save the selected model's state dict to ``models/behor``.
    5. Plot ROC and PR curves on the test split.

    Args:
        epochs: number of training epochs.
        hidden_dim: hidden embedding width.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        dropout: dropout probability after the first GNN layer.

    Returns:
        The model returned by ``train`` (the best validation-loss copy).
    """
    import os

    # Follow the module-level device instead of forcing 'cuda', so graph
    # construction also works on CPU-only machines ('cuda' on GPU hosts).
    constructor = heterograph_construction.DISNETConstructor(device=str(device))
    hetero, _ = constructor.DISNETHeterographDeepSnap(full=False, withoutRepoDB=True)

    edge_train_mode = 'disjoint'
    print('edge train mode: {}'.format(edge_train_mode))
    dataset = GraphDataset(
        [hetero],
        task='link_pred',
        edge_train_mode=edge_train_mode,
        edge_message_ratio=0.8,
        edge_negative_sampling_ratio=1
    )

    dataset_train, dataset_val, dataset_test = dataset.split(
        transductive=True, split_ratio=[0.8, 0.1, 0.1], shuffle=True
    )

    train_loader = DataLoader(
        dataset_train, collate_fn=Batch.collate(), batch_size=1
    )
    val_loader = DataLoader(
        dataset_val, collate_fn=Batch.collate(), batch_size=1
    )
    test_loader = DataLoader(
        dataset_test, collate_fn=Batch.collate(), batch_size=1
    )
    dataloaders = {
        'train': train_loader, 'val': val_loader, 'test': test_loader
    }

    convs = generate_convs_link_pred_layers(hetero, HeteroSAGEConv, hidden_dim)
    model = HeteroGNN(convs, hetero, hidden_dim, dropout).to(device)
    optimizer = torch.optim.Adam(
        model.parameters(), lr=lr, weight_decay=weight_decay)

    print("Started training at", datetime.now().strftime("%H:%M:%S"))
    _, _, _, model = train(model, dataloaders, optimizer, epochs)
    print("Finished training at", datetime.now().strftime("%H:%M:%S"))

    os.makedirs("models", exist_ok=True)  # torch.save does not create directories
    torch.save(model.state_dict(), "models/behor")

    # Testing
    model.eval()

    print("Started testing at", datetime.now().strftime("%H:%M:%S"))
    pure_pred_labels, true_labels, keys = test2(model, test_loader)
    print("Finished testing at", datetime.now().strftime("%H:%M:%S"))

    # Flatten the per-(batch, message type) arrays into flat lists for plotting.
    labels = [item for sublist in true_labels for item in sublist]
    pure_predictions = [item for sublist in pure_pred_labels for item in sublist]

    plot_roc(labels, pure_predictions, keys[0], "behor/")
    plot_prc(torch.tensor(labels), pure_predictions, keys[0], "behor/")
    return model
ADRIAN  AYUSO MUNOZ's avatar
ADRIAN AYUSO MUNOZ committed
328 329 330


if __name__ == '__main__':
    # Hard-coded hyper-parameters for this training run.
    main(
        epochs=2752,
        hidden_dim=107,
        lr=0.0008317,
        weight_decay=0.006314,
        dropout=0.8,
    )