# autoencoder.py (4.07 KB)
# Committed by ADRIAN AYUSO MUNOZ
import torch
from torch_geometric.nn import SAGEConv
import random
import math

"""
Graph convolutional encoder for a graph auto-encoder, built from SAGEConv layers.
Reference: https://arxiv.org/abs/1611.07308
"""
class GCNEncoder(torch.nn.Module):
    """
    Two-layer graph encoder: SAGEConv -> LeakyReLU -> SAGEConv.

    Input:
        in_channels: Size of the input embeddings.
        out_channels: Size of the output embeddings.
    """

    def __init__(self, in_channels, out_channels):
        """
        Build the two convolutions and the activation used between them.
        Input:
            in_channels: Size of the input embeddings.
            out_channels: Size of the output embeddings.
        """
        super().__init__()
        # The intermediate representation is twice as wide as the final one.
        hidden_channels = 2 * out_channels
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)
        self.lrelu = torch.nn.LeakyReLU()

    def forward(self, x, edge_index):
        """
        Encode the nodes: first convolution, leaky ReLU, second convolution.
        Input:
            x: All the features of the graph's nodes.
            edge_index: Two lists which contain the edges of the graph (first contains heads and second tails)
        Output:
            Encoded embeddings.
        """
        hidden = self.lrelu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)

"""
Class to wrap the training and testing functions.
"""
class Trainer:
    """
    Wraps the training and testing loops of a graph auto-encoder.

    The model must provide the methods this class calls on it:
    ``train()``, ``eval()``, ``encode(x, edge_index)``,
    ``recon_loss(z, edge_index, neg_edge_index)`` and
    ``test(z, edge_index, neg_edge_index)``.
    """

    def __init__(self, model, optimizer, dataset, train_size=2864):
        """
        Store the model, optimizer and data, then split the data into train and test sets.

        The dataset is shuffled IN PLACE with ``random.shuffle`` before the split, so it
        must be a mutable sequence (e.g. a list) and the caller's object is reordered.

        Input:
            model: Model to be trained.
            optimizer: Optimizer to train the model.
            dataset: Dataset on which the model will be trained.
            train_size: Number of (shuffled) samples used for training; the remainder is
                the test set. Defaults to 2864, the split the original code hard-coded.
        """
        self.model = model
        self.optimizer = optimizer
        self.dataset = dataset
        random.shuffle(self.dataset)
        self.trainD = self.dataset[:train_size]
        self.testD = self.dataset[train_size:]

    @staticmethod
    def _mean(total, count):
        """Return total / count, or NaN when nothing was accumulated."""
        return total / count if count else math.nan

    def train(self, x, edge_index, neg_edge_index):
        """
        Run one optimisation step on a single graph and return its loss.

        Input:
            x: Features of the graph's nodes.
            edge_index: Two lists which contain the edges of the graph (first contains heads and second tails)
            neg_edge_index: Two lists which contain negative edges, not present in the graph (first contains
                heads and second tails)
        Output:
            The loss as a float. If the loss is NaN no backward pass or optimizer step is
            performed and 1 is returned instead.
        """
        self.model.train()
        self.optimizer.zero_grad()
        z = self.model.encode(x, edge_index)
        loss = self.model.recon_loss(z, edge_index, neg_edge_index)

        # Convert once; the same value is used for the NaN check and as the result.
        loss_value = float(loss)
        if math.isnan(loss_value):
            # Skip the update so a NaN loss never reaches the parameters.
            return 1

        loss.backward()
        self.optimizer.step()
        return loss_value

    def test(self, x, edge_index, neg_edge_index):
        """
        Evaluate the model on a single graph.

        Input:
            x: Features of the graph's nodes.
            edge_index: Two lists which contain the edges of the graph (first contains heads and second tails)
            neg_edge_index: Two lists which contain negative edges, not present in the graph (first contains
                heads and second tails)
        Output:
            Whatever ``model.test`` returns; ``fit`` reads element 0 as AUC and element 1 as AP.
        """
        self.model.eval()
        with torch.no_grad():
            z = self.model.encode(x, edge_index)
        return self.model.test(z, edge_index, neg_edge_index)

    def fit(self, epochs):
        """
        Train for a number of epochs, printing the averaged metrics after each one.

        Each epoch reshuffles the training set, trains on every training graph, then
        evaluates on every test graph that has at least one edge.

        Input:
            epochs: Number of epochs to train the model.
        Output:
            A list with one ``(auc, ap, loss)`` tuple per epoch, holding the averages that
            were printed. An average with nothing to average over (empty train set, empty
            test set, or only edge-less test graphs) is reported as NaN instead of raising
            ZeroDivisionError.
        """
        history = []
        for epoch in range(1, epochs + 1):
            random.shuffle(self.trainD)

            loss_sum = 0
            for data in self.trainD:
                loss_sum += self.train(data.x, data.edge_index, data.neg_edge_index)

            auc_sum, ap_sum, evaluated = 0, 0, 0
            for data in self.testD:
                # Graphs without edges are skipped; len() is used (not truthiness) so the
                # check also works when edge_index[0] is a tensor.
                if len(data.edge_index[0]) != 0:
                    res = self.test(data.x, data.edge_index, data.neg_edge_index)
                    auc_sum += res[0]
                    ap_sum += res[1]
                    evaluated += 1

            auc = self._mean(auc_sum, evaluated)
            ap = self._mean(ap_sum, evaluated)
            loss = self._mean(loss_sum, len(self.trainD))
            print('Epoch: {:03d}, AUC: {:.4f}, AP: {:.4f}, Loss: {:0.3f}'.format(epoch, auc, ap, loss))
            history.append((auc, ap, loss))
        return history