import numpy as np
from pysmiles import read_smiles
import pandas as pd
import mysql.connector
import torch
from autoencoder import GCNEncoder, Trainer
from torch_geometric.nn import GAE
import itertools
import random
from pubchempy import *

# Execute on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

Class that contains the data associated to the drug graphs.
class DrugGraph:
    It instantiates the drug molecular graph.
        graph: Graph.
        feats: Feats of the graph nodes..
        edge_index: Two lists which contain the edges of the graph (first contains heads and second tails)
        neg_edge_index: Two lists which contain negative edges, not present in the graph (first contains heads and
            second tails)
    def __init__(self, graph, feats, edge_index, neg_edge_index):
        self.graph = graph
        self.x = feats
        self.edge_index = edge_index
        self.neg_edge_index = neg_edge_index

def getSmiles():
    """
    It generates a TSV file (data/druStruc.tsv) with the SMILES representations of the DISNET's graph drugs.

    Connection settings are read from environment variables so that no credentials are stored in the source code:
        DISNET_DB_USER: Database user (default 'edsss_usr').
        DISNET_DB_PASSWORD: Database password (required).
        DISNET_DB_HOST: Database host (default 'ares.ctb.upm.es').
        DISNET_DB_PORT: Database port (default '30100').
        DISNET_DB_NAME: Database (schema) containing the ``drug`` table (required, the query does not qualify it).

    Raises:
        RuntimeError: If a required environment variable is not set.
    """
    import os

    password = os.environ.get('DISNET_DB_PASSWORD')
    database = os.environ.get('DISNET_DB_NAME')
    if password is None or database is None:
        raise RuntimeError("DISNET_DB_PASSWORD and DISNET_DB_NAME must be set to query the DISNET database.")

    cnx = mysql.connector.connect(user=os.environ.get('DISNET_DB_USER', 'edsss_usr'), password=password,
                                  host=os.environ.get('DISNET_DB_HOST', 'ares.ctb.upm.es'),
                                  port=os.environ.get('DISNET_DB_PORT', '30100'), database=database)
    # Release the cursor and the connection even if the query fails.
    try:
        cursor = cnx.cursor()
        try:
            cursor.execute("SELECT drug_id, drug_name, chemical_structure FROM drug;")
            rows = cursor.fetchall()
        finally:
            cursor.close()
    finally:
        cnx.close()

    # Giving the column names here (instead of assigning .columns afterwards) also works for an empty result.
    newDf = pd.DataFrame(rows, columns=['id', 'name', 'stru'])
    newDf.to_csv("data/druStruc.tsv", sep='\t', index=False)

def getGraph(df, complete=False):
    """
    It transforms the SMILES representations to their graph representations. When a drug has no usable SMILES, a 0
    is stored instead of its graph.

    The dataframe is updated in place (completed SMILES are written back, unresolved drugs get 0 in 'stru') and then
    saved to data/druStruc.tsv.

    Args:
        df: Dataframe containing the drugs' names ('name' column) and SMILES ('stru' column).
        complete: Decides if drugs without a usable SMILES should be completed by searching them in PubChem by name.
    Returns:
        1-D object array with one entry per row of df: the directed NetworkX graph of the drug, or 0.
    """
    networks = []
    for name, smiles in zip(df['name'].tolist(), df['stru'].tolist()):
        if isinstance(smiles, str) and smiles != '0':
            # pysmiles may warn that stereochemical information will be discarded.
            new = read_smiles(smiles).to_directed()
        elif complete:
            compounds = get_compounds(name, 'name')
            if len(compounds) > 0:
                smiles = compounds[0].canonical_smiles
                # Directed like the main path, so every graph lists each bond in both directions.
                new = read_smiles(smiles).to_directed()
                df.loc[df["name"] == name, "stru"] = smiles
            else:
                df.loc[df["name"] == name, "stru"] = 0
                new = 0
        else:
            new = 0
            df.loc[df["name"] == name, "stru"] = 0
        networks.append(new)

    df.to_csv("data/druStruc.tsv", sep='\t', index=False)

    # Fill the object array element by element: np.array(list_of_graphs, dtype=object) treats each graph as a
    # sequence of its nodes and, when all graphs have the same size, returns a 2-D array of node ids instead.
    result = np.empty(len(networks), dtype=object)
    for i, net in enumerate(networks):
        result[i] = net
    return result

def _nodeFeatures(graph, elements):
    """
    It encodes every node of a drug graph as [element, charge, aromatic, hcount, stereo].

    Args:
        graph: NetworkX graph of the drug.
        elements: Conversion from element symbol to int; must contain the 'Error' key.
    Returns:
        List with one feature row per node, in graph.nodes order.
    """
    feats = []
    for _, node in graph.nodes(data=True):
        # Elements missing from the conversion table get the 'Error' code instead of raising KeyError.
        element = elements.get(node['element'], elements['Error'])
        # Atoms without stereochemical information get 0.
        stereo = len(node['stereo']) if 'stereo' in node else 0
        feats.append([element, node['charge'], node['aromatic'], node['hcount'], stereo])
    return feats


def _edgeIndices(graph, numNodes):
    """
    It builds the positive and negative edges of a drug graph in edge_index format (heads, tails).

    Negative edges are node pairs (i, j) with i < j that are not edges of the graph. When there are more of them than
    positive edges they are randomly subsampled to the same number; if none remain, the dummy edge (0, 0) is used.

    Args:
        graph: NetworkX graph of the drug.
        numNodes: Number of nodes of the graph.
    Returns:
        Tuple (edges, neg_edges), each of them two lists [heads, tails].
    """
    graphEdges = list(graph.edges)
    # Set membership instead of scanning the candidate list once per edge.
    present = set(graphEdges)
    # Assumes node ids are 0..numNodes-1 (feature rows are indexed positionally) — TODO confirm for pysmiles output.
    negEdges = [pair for pair in itertools.combinations(range(numNodes), 2) if pair not in present]

    heads = [e[0] for e in graphEdges]
    tails = [e[1] for e in graphEdges]

    # If there are more negative edges than positive ones, reduce the negatives to match the positives.
    if len(tails) < len(negEdges):
        negEdges = random.sample(negEdges, len(heads))

    # If there are no negative edges, one needs to be introduced.
    if len(negEdges) == 0:
        return [heads, tails], [[0], [0]]
    return [heads, tails], [[e[0] for e in negEdges], [e[1] for e in negEdges]]


def buildDataset(graphs):
    """
    It builds a dataset made of the drugs' molecular structures.

    Args:
        graphs: Array containing the NetworkX objects representing the drugs (0 for drugs without a structure).
    Returns:
        List containing one DrugGraph per entry of graphs, with tensors placed on ``device``.
    """
    res = []
    # Dictionary containing the conversion from element to int.
    elements = {'Error': -1, 'Ag': 0, 'Al': 1, 'As': 2, 'Au': 3, 'B': 4, 'Ba': 5, 'Bi': 6, 'Br': 7, 'C': 8, 'Ca': 9,
                'Cl': 10, 'Co': 11, 'Cr': 12, 'Cu': 13, 'F': 14, 'Fe': 15, 'Ga': 16, 'Gd': 17, 'H': 18, 'He': 19,
                'Hg': 20, 'I': 21, 'In': 22, 'K': 23, 'Kr': 24, 'La': 25, 'Li': 26, 'Lu': 27, 'Mg': 28, 'Mn': 29,
                'N': 30, 'Na': 31, 'O': 32, 'P': 33, 'Pt': 34, 'Ra': 35, 'Rb': 36, 'S': 37, 'Sb': 38, 'Se': 39,
                'Si': 40, 'Sm': 41, 'Sn': 42, 'Sr': 43, 'Tc': 44, 'Ti': 45, 'Tl': 46, 'Xe': 47, 'Yb': 48, 'Zn': 49}

    for graph in graphs:
        if graph != 0:
            feats = _nodeFeatures(graph, elements)
            edges, negEdges = _edgeIndices(graph, len(feats))
        else:
            # Error case: a single placeholder node, no edges and a dummy negative edge.
            feats = [[-1, -1, -1, -1, -1]]
            edges = [[], []]
            negEdges = [[0], [0]]

        res.append(DrugGraph(graph, torch.tensor(feats, device=device, dtype=torch.float),
                             torch.tensor(edges, device=device, dtype=torch.int64),
                             torch.tensor(negEdges, device=device, dtype=torch.int64)))
    return res

def trainAE(model, optimizer, dataset, epochs):
    """
    It trains the autoencoder model.

    Args:
        model: Autoencoder model.
        optimizer: Optimizer to use along training.
        dataset: Dataset to train the autoencoder.
        epochs: Number of epochs to train the autoencoder.
    """
    # NOTE(review): `epochs` is never used and the Trainer is only constructed, never run, so no training happens
    # here. A call into the Trainer (its API lives in autoencoder.py, outside this view) appears to be missing — TODO.
    trainer = Trainer(model, optimizer, dataset)

def getEmbed(model, dataset):
    """
    It generates one embedding per graph with the encoder of the autoencoder (the code of the autoencoder).

    Each drug embedding is the mean of its node encodings. The model is moved to the CPU (in place) and switched to
    evaluation mode.

    Args:
        model: Autoencoder.
        dataset: Dataset to generate embeddings from (objects exposing ``x`` and ``edge_index``).
    Returns:
        List with the drug molecular structure embeddings, in dataset order.
    """
    embeddings = []
    model = model.to('cpu')
    # Inference: disables dropout / batch-norm statistic updates if the encoder has any.
    model.eval()
    with torch.no_grad():
        for data in dataset:
            nodeCodes = model.encode(data.x.cpu(), data.edge_index.cpu())
            embeddings.append(torch.mean(nodeCodes, dim=0))
    return embeddings

if __name__ == '__main__':

    df = pd.read_csv('data/druStruc.tsv', sep='\t')
    graphs = getGraph(df)

    dataset = buildDataset(graphs)
    # Only the elements without errors are incorporated in the training dataset.
    dataset2 = [elem for elem in dataset if elem.graph != 0]

    # Training. The encoder takes the 5 node features built by buildDataset (element, charge, aromatic, hcount,
    # stereo); embedDim is its output size and is reused for the placeholder embeddings below.
    embedDim = 32
    model = GAE(GCNEncoder(5, embedDim))
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    trainAE(model, optimizer, dataset2, 500)

    # Getting Embeddings
    embeddings = getEmbed(model, dataset)
    # Saving model.
    torch.save(model.state_dict(), "./models/structureEmbedder")

    # Those drugs without embedding are assigned a predefined embedding (all ones).
    for i, elem in enumerate(dataset):
        if elem.graph == 0:
            embeddings[i] = torch.tensor([1] * embedDim, dtype=torch.float32)

    # Save the embeddings.
    df["embedding"] = embeddings
    df.to_csv("data/features/dru.tsv", sep='\t', index=False)
    torch.save(embeddings, 'data/features/dru.pt')