Commit 9e2d05a0 authored by Adrián Ayuso-Muñoz

Initial Commit
# Graph Deep Learning for Drug Repurposing
## Content in each directory:
- **data**: Data to build DISNET's graph.
- **documentation**: Instructions to install the needed libraries.
- **metrics**: ROC & PRC curves for training, testing and RepoDB validation.
- **models**: Trained models.
- **results**: Result files of the RepoDB test and the distribution plots (once these are generated).
- **testData**: Data to validate the model using RepoDB.
- **Code files:**
  - autoencoder: drug molecular embedder model.
  - dmsr: drug repurposing model.
  - drug_embedding_generator: generates drug embeddings from their SMILES representations.
  - heterograph_construction: builds the graph.
  - testRepoDB: validates the model using RepoDB.
  - topN: gets the top-N new predictions.
  - utilities: plotting utilities.
## Summary
Repository of Adrián Ayuso-Muñoz's master's final project "Graph Deep Learning for Drug Repurposing".
import math
import random

import torch
from torch_geometric.nn import SAGEConv


class GCNEncoder(torch.nn.Module):
    """
    Graph convolutional network encoder for a graph autoencoder
    (https://arxiv.org/abs/1611.07308), built here with two SAGEConv layers.

    It instantiates the GCNEncoder and all its layers (two convolutions and a
    leaky ReLU).

    Input:
        in_channels: Size of the input embeddings.
        out_channels: Size of the output embeddings.
    """

    def __init__(self, in_channels, out_channels):
        super(GCNEncoder, self).__init__()
        self.conv1 = SAGEConv(in_channels, 2 * out_channels)
        self.conv2 = SAGEConv(2 * out_channels, out_channels)
        self.lrelu = torch.nn.LeakyReLU()

    def forward(self, x, edge_index):
        """
        Applies the two convolutions to the graph, with a leaky ReLU in between.

        Input:
            x: Features of all the graph's nodes.
            edge_index: Two lists containing the graph's edges (the first holds
                heads, the second tails).
        Output:
            Encoded embeddings.
        """
        x = self.conv1(x, edge_index)
        x = self.lrelu(x)
        return self.conv2(x, edge_index)


class Trainer:
    """
    Wraps the training and testing functions.

    It instantiates the Trainer, saving all the data needed to train and test
    the model.

    Input:
        model: Model to be trained.
        optimizer: Optimizer used to train the model.
        dataset: Dataset on which the model will be trained and tested.
    """

    def __init__(self, model, optimizer, dataset):
        self.model = model
        self.optimizer = optimizer
        self.dataset = dataset
        random.shuffle(self.dataset)
        # Fixed-size train/test split.
        self.trainD = self.dataset[0:2864]
        self.testD = self.dataset[2864:]

    def train(self, x, edge_index, neg_edge_index):
        """
        Trains the model on one graph and computes its loss.

        Input:
            x: Features of the graph's nodes.
            edge_index: Two lists containing the graph's edges (the first holds
                heads, the second tails).
            neg_edge_index: Two lists containing negative edges, i.e. edges not
                present in the graph.
        Output:
            The model's loss for the given graph; 1 if the loss is NaN.
        """
        self.model.train()
        self.optimizer.zero_grad()
        z = self.model.encode(x, edge_index)
        loss = self.model.recon_loss(z, edge_index, neg_edge_index)
        if not math.isnan(loss):
            loss.backward()
            self.optimizer.step()
            return float(loss)
        else:
            return 1

    def test(self, x, edge_index, neg_edge_index):
        """
        Tests the model and computes its metrics.

        Input:
            x: Features of the graph's nodes.
            edge_index: Two lists containing the graph's edges (the first holds
                heads, the second tails).
            neg_edge_index: Two lists containing negative edges, i.e. edges not
                present in the graph.
        Output:
            Metrics (AUROC and average precision) for the given graph.
        """
        self.model.eval()
        with torch.no_grad():
            z = self.model.encode(x, edge_index)
        return self.model.test(z, edge_index, neg_edge_index)

    def fit(self, epochs):
        """
        Trains the model for the given number of epochs, printing the averaged
        metrics (AUROC, average precision and loss) after each epoch.

        Input:
            epochs: Number of epochs to train the model.
        """
        for epoch in range(1, epochs + 1):
            auc, ap, i, j, loss = 0, 0, 0, 0, 0
            random.shuffle(self.trainD)
            for data in self.trainD:
                loss += self.train(data.x, data.edge_index, data.neg_edge_index)
                i += 1
            for data in self.testD:
                if len(data.edge_index[0]) != 0:
                    res = self.test(data.x, data.edge_index, data.neg_edge_index)
                    auc += res[0]
                    ap += res[1]
                    j += 1
            print('Epoch: {:03d}, AUC: {:.4f}, AP: {:.4f}, Loss: {:0.3f}'.format(
                epoch, auc / j, ap / j, loss / i))
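For orientation, here is a minimal, hypothetical sketch of how these two classes are wired together with PyG's GAE wrapper; it mirrors their actual use in drug_embedding_generator below, and assumes `dataset` is a list of objects exposing `x`, `edge_index` and `neg_edge_index` tensors:

```python
import torch
from torch_geometric.nn import GAE

# Assumed: `dataset` is a list of graph objects with .x, .edge_index and
# .neg_edge_index tensors, as built by drug_embedding_generator below.
model = GAE(GCNEncoder(in_channels=5, out_channels=32))  # sizes used by the embedding generator
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
trainer = Trainer(model, optimizer, dataset)
trainer.fit(epochs=500)
```

GAE supplies the `encode`, `recon_loss` and `test` methods that Trainer calls, so the encoder itself only has to implement `forward`.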
dis pat
C0020538 WP554
C0020538 WP2197
C0020538 WP4756
C0018799 WP1544
C0018799 WP1528
C0018799 WP516
C0018799 WP1591
C0027947 WP229
C0013369 WP229
C0006826 WP229
C0002892 WP1533
C0030567 WP2371
C0752347 WP2371
C0524851 WP2371
C1285162 WP2371
C0030567 WP4222
C0002395 WP4222
C0036341 WP4222
C0020179 WP4222
C0028768 WP4222
C0041671 WP4222
C0040517 WP4222
C0030567 WP2355
C0002395 WP2355
C0024517 WP2355
C0154409 WP2355
C0003125 WP2355
C0006826 WP2446
C0006826 WP1742
C0006826 WP4747
C0006826 WP4262
C0006826 WP4284
C0006826 WP3858
C0006826 WP4565
C0027819 WP4565
C0006826 WP3678
C0019562 WP3678
C0006826 WP673
C0002395 WP673
C0026769 WP673
C0006826 WP2363
C0024623 WP2363
C0038356 WP2363
C0153421 WP2363
C0153422 WP2363
C0153423 WP2363
C0278701 WP2363
C0006826 WP4205
C1306837 WP4205
C1336078 WP4205
C0006826 WP1971
C0006826 WP4534
C0006826 WP3613
C0006826 WP4582
C0006826 WP3617
C0006826 WP4559
C0006826 WP4585
C0006826 WP3967
C0919267 WP3967
C1140680 WP3967
C1299247 WP3967
C0006826 WP2870
C0027765 WP2870
C0006826 WP3612
C0006826 WP4172
C0006826 WP699
C0006826 WP3614
C0006826 WP4658
C0024624 WP4658
C0153491 WP4658
C0153492 WP4658
C0153493 WP4658
C0006826 WP2828
C0005684 WP2828
C0005695 WP2828
C1306837 WP2828
C1336078 WP2828
C0005686 WP2828
C0006826 WP4485
C0006826 WP1601
C0006826 WP3611
C0006826 WP3298
C0154564 WP3298
C0006826 WP2261
C0017636 WP2261
C0278878 WP2261
C1514422 WP2261
C0006118 WP2261
C0153633 WP2261
C0220624 WP2261
C0750974 WP2261
C0750979 WP2261
C0006826 WP2526
C0006826 WP2868
C0919267 WP2868
C1140680 WP2868
C1299247 WP2868
C0006826 WP4016
C0007114 WP4016
C0006826 WP3879
C0006142 WP3879
C0006826 WP4542
C0038436 WP2431
C0037928 WP2431
C0038436 WP2848
C0037928 WP2848
C0023772 WP143
C0023772 WP206
C0011849 WP3635
C0011849 WP690
C0011849 WP1584
C0020456 WP1584
C0002395 WP3945
C0002395 WP2059
C0007222 WP430
C0007222 WP536
C0007222 WP3668
C0021359 WP34
C0041295 WP4197
C1864436 WP4767
C0021400 WP1904
C0003723 WP1904
C0042769 WP1904
C0036341 WP4875
C0036341 WP4905
C0024517 WP4905
C0154409 WP4905
C0005586 WP4905
C0005587 WP4905
C0024713 WP4905
C0236780 WP4905
C0033578 WP3872
C0376358 WP3872
C0033578 WP4301
C0376358 WP4301
C0033578 WP3982
C0376358 WP3982
C0033578 WP3981
C0376358 WP3981
C1175175 WP4864
C0003723 WP4864
C0042769 WP4864
C1175175 WP4863
C0003723 WP4863
C0042769 WP4863
C1175175 WP4861
C0003723 WP4861
C0042769 WP4861
C1175175 WP4912
C0003723 WP4912
C0042769 WP4912
C1175175 WP4853
C1175175 WP4877
C0003723 WP4877
C0042769 WP4877
C1175175 WP4880
C0003723 WP4880
C0042769 WP4880
C1175175 WP4868
C0003723 WP4868
C0042769 WP4868
C1175175 WP4891
C0003723 WP4891
C0042769 WP4891
C0034069 WP4891
C1175175 WP4884
C0003723 WP4884
C0042769 WP4884
C1175175 WP4883
C0003723 WP4883
C0042769 WP4883
C1175175 WP4860
C0003723 WP4860
C0042769 WP4860
C1175175 WP4876
C0003723 WP4876
C0042769 WP4876
C0346629 WP3969
C0009404 WP3969
C0346629 WP4290
C0009404 WP4290
C0346629 WP4239
C0009404 WP4239
C0346629 WP4258
C0009404 WP4258
C0346629 WP4216
C0009404 WP4216
C0003723 WP4298
C0042769 WP4298
C0027059 WP4298
C0003723 WP4655
C0042769 WP4655
C0004623 WP4655
C0003723 WP3865
C0042769 WP3865
C0003723 WP4799
C0042769 WP4799
C0003723 WP4217
C0042769 WP4217
C0282687 WP4217
C0282312 WP3849
C0021364 WP4673
C0030297 WP3680
C0153458 WP3680
C0153459 WP3680
C0153460 WP3680
C0153463 WP3680
C0023418 WP3658
C0733682 WP4790
C1845168 WP4790
C3540852 WP4790
C0878676 WP4156
C1291564 WP4156
C0268468 WP4156
C0002888 WP4156
C1291564 WP4220
C0342687 WP4220
C0220993 WP4292
C0268616 WP4292
C0268624 WP4292
C2931746 WP4292
C0019880 WP4292
C0017606 WP3874
C0033300 WP4320
C0033300 WP4879
C0410189 WP4879
C1720860 WP4879
C1720861 WP4879
C1720859 WP4879
C0271694 WP4879
C0033141 WP4879
C0036529 WP4879
C0878544 WP4879
C0033300 WP4299
C0012634 WP4299
C0600327 WP3863
C0035372 WP4453
C0035372 WP4312
C0035372 WP3584
C0036875 WP4814
C0036875 WP4842
C0549622 WP4842
C0022658 WP4150
C0022658 WP4758
C0022658 WP4838
C0035078 WP4838
C0028064 WP4545
C0238052 WP4545
C0028064 WP4153
C0012634 WP4153
C0268255 WP4153
C0023522 WP4153
C0023521 WP4153
C0002986 WP4153
C0017205 WP4153
C0017083 WP4153
C0039373 WP4153
C0268274 WP4153
C0085131 WP4153
C0036161 WP4153
C0268275 WP4153
C0524851 WP4760
C1285162 WP4760
C0206307 WP4519
C0029442 WP1531
C0035579 WP1531
C0005684 WP2291
C0005695 WP2291
C0005684 WP3670
C0005695 WP3670
C0003469 WP3947
C0003469 WP3944
C0017636 WP3593
C0278878 WP3593
C1514422 WP3593
C0024623 WP2361
C0038356 WP2361
C0153421 WP2361
C0153422 WP2361
C0153423 WP2361
C0278701 WP2361
C0038644 WP706
C0017551 WP1604
C0002736 WP2447
C0029434 WP4786
C0015625 WP3569
C0019247 WP4804
C0019247 WP3674
C0019247 WP4856
C0027765 WP4856
C0007959 WP4856
C0007194 WP2795
C0949658 WP2795
C1378703 WP4204
C0028754 WP3965
C0028754 WP3407
C0028754 WP2865
C1257763 WP2865
C0007115 WP3859
C0040136 WP3859
C0014544 WP3871
C0014544 WP4829
C0014544 WP4313
C0919267 WP3972
C1140680 WP3972
C1299247 WP3972
C0919267 WP4560
C1140680 WP4560
C1299247 WP4560
C0919267 WP4400
C1140680 WP4400
C1299247 WP4400
C0919267 WP4397
C1140680 WP4397
C1299247 WP4397
C0919267 WP3301
C1140680 WP3301
C1299247 WP3301
C0919267 WP3303
C1140680 WP3303
C1299247 WP3303
C0279702 WP4018
C0019693 WP3300
C0019693 WP3414
C0008924 WP3655
C0006015 WP3679
C2931845 WP4577
C0154319 WP4008
C1306837 WP4241
C1336078 WP4241
C0015624 WP4917
C0341703 WP4917
C0265268 WP4787
C0005940 WP4787
C0268547 WP4595
C0154246 WP4595
C0268542 WP4595
C0175683 WP4595
C0268548 WP4595
C0751753 WP4595
C0268547 WP4583
C0154246 WP4583
C0268542 WP4583
C0175683 WP4583
C0268548 WP4583
C0751753 WP4583
C0268547 WP4571
C0154246 WP4571
C0268542 WP4571
C0175683 WP4571
C0268548 WP4571
C0751753 WP4571
C0012634 WP2491
C0701163 WP4523
C0033804 WP4523
C0342488 WP4523
C1257958 WP3934
C0023473 WP2290
C0023473 WP3640
C0034069 WP3624
C0025521 WP4521
C0282577 WP4521
C0025521 WP4506
C0002066 WP4506
C0025521 WP4288
C0346302 WP4194
C0410189 WP4535
C0017668 WP2572
C1332167 WP3651
C0020474 WP4522
C0039292 WP4522
C0020597 WP4522
C0268118 WP4507
C1863688 WP4507
C1854989 WP4507
C1854990 WP4507
C0281361 WP4263
C0007134 WP4206
C0206654 WP4206
C0018051 WP4872
C0018054 WP4872
C0949595 WP4872
C0023434 WP4399
C0020179 WP3853
C0268524 WP4518
C0002878 WP4518
C0004623 WP2665
C0004153 WP3926
C0153368 WP1589
C0040128 WP1981
C0007103 WP4155
C0014170 WP4155
C0010308 WP4746
C0342200 WP4746
C0019163 WP4666
C0268120 WP4224
C3665382 WP4224
C0023374 WP4224
C0268124 WP4224
C0029928 WP4871
C0017411 WP4871
C0020630 WP4228
C0220743 WP4228
C0032897 WP3998
C0162635 WP3998
C0011570 WP4698
C0025202 WP4685
C0024624 WP4255
C0153491 WP4255
C0153492 WP4255
C0153493 WP4255
C1959620 WP4584
C3495551 WP4584
C1959620 WP4225
C3495551 WP4225
C0016667 WP4549
C0349788 WP2118
C0278752 WP4540
C0345967 WP4540
C0392400 WP4540
C1332338 WP4540
C0220704 WP4657
C0162534 WP3995
C0162666 WP4236
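Each row of the file above pairs a disease identifier (a UMLS CUI such as C0020538) with a WikiPathways pathway identifier (such as WP554). A minimal sketch of loading it for graph construction follows; the file name data/disPat.tsv and the tab separator are assumptions for illustration, since the actual path is not shown in this diff:

```python
import pandas as pd

# Hypothetical path; the real file name inside data/ is not visible here.
dis_pat = pd.read_csv('data/disPat.tsv', sep='\t')
# Columns: 'dis' (UMLS CUI of the disease), 'pat' (WikiPathways pathway ID).
edges = list(zip(dis_pat['dis'], dis_pat['pat']))
```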
Development was carried out with Python 3.8.10.

To install all the libraries, run the following command:

`pip install -r libsImport.txt`

Some libraries may produce an error; in that case, delete the corresponding line and run the command again.

**IMPORTANT**: PyTorch (torch) will be installed as a dependency of other packages; it should be uninstalled and reinstalled using the command shown below.

**IMPORTANT**: Most libraries are compiled against CUDA 11.3; if you are using a different version, please adapt the installation accordingly.

Here are some libraries that usually produce errors, and the way to install them:
- **DGL**: pip install dgl-cu113==0.7.2 -f https://data.dgl.ai/wheels/repo.html
- **SNAP**: pip install snap-stanford
- **PyTorch**: pip install torch==1.10.2 torchvision==0.11.3 torchaudio==0.10.2 --extra-index-url https://download.pytorch.org/whl/cu113
- **PyTorchExtra**: pip install torch-scatter==2.0.9 torch-sparse==0.6.12 torch-cluster==1.5.9 torch-spline-conv==1.2.1 torch-geometric==2.0.3 -f https://data.pyg.org/whl/torch-1.10.0+cu113.html
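After installation, a quick sanity check (a minimal sketch, not part of the original instructions) can confirm that the CUDA 11.3 builds were picked up:

```python
import torch
import torch_geometric
import dgl

print(torch.__version__)            # expected: 1.10.2+cu113
print(torch.version.cuda)           # expected: 11.3
print(torch.cuda.is_available())    # True if the GPU setup is working
print(torch_geometric.__version__)  # expected: 2.0.3
print(dgl.__version__)              # expected: 0.7.2
```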
asttokens==2.0.5
attrs==21.4.0
certifi==2021.10.8
charset-normalizer==2.0.12
click==8.1.2
cycler==0.11.0
deepsnap==0.2.1
dgl-cu113==0.7.2
executing==0.8.2
Flask==2.1.1
fonttools==4.29.1
googledrivedownloader==0.4
idna==3.3
importlib-metadata==4.11.3
iniconfig==1.1.1
install==1.3.5
isodate==0.6.1
itsdangerous==2.1.2
Jinja2==3.0.3
joblib==1.1.0
kiwisolver==1.3.2
littleutils==0.2.2
MarkupSafe==2.1.0
matplotlib==3.5.1
networkx==2.6.3
numpy==1.22.2
nvidia-smi==0.1.3
packaging==21.3
pandas==1.4.1
Pillow==9.0.1
pluggy==1.0.0
py==1.11.0
pyparsing==3.0.7
pytest==7.0.1
python-dateutil==2.8.2
pytz==2021.3
PyYAML==6.0
rdflib==6.1.1
requests==2.27.1
scikit-learn==1.0.2
scipy==1.8.0
seaborn==0.11.2
six==1.16.0
sklearn==0.0
snap-stanford==6.0.0
sorcery==0.2.2
threadpoolctl==3.1.0
tomli==2.0.1
torch==1.10.2+cu113
torch-cluster==1.5.9
torch-geometric==2.0.3
torch-scatter==2.0.9
torch-sparse==0.6.12
torch-spline-conv==1.2.1
torchaudio==0.10.2+cu113
torchvision==0.11.3+cu113
tqdm==4.62.3
typing-extensions==4.1.1
urllib3==1.26.8
Werkzeug==2.1.1
wrapt==1.13.3
yacs==0.1.8
zipp==3.8.0
import itertools
import random

import mysql.connector
import numpy as np
import pandas as pd
import torch
from pubchempy import get_compounds
from pysmiles import read_smiles
from torch_geometric.nn import GAE

from autoencoder import GCNEncoder, Trainer

# It defines whether to execute on CPU or GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class DrugGraph:
    """
    Contains the data associated with a drug's molecular graph.

    Input:
        graph: NetworkX graph of the molecule.
        feats: Features of the graph's nodes.
        edge_index: Two lists containing the graph's edges (the first holds
            heads, the second tails).
        neg_edge_index: Two lists containing negative edges, i.e. edges not
            present in the graph.
    """

    def __init__(self, graph, feats, edge_index, neg_edge_index):
        self.graph = graph
        self.x = feats
        self.edge_index = edge_index
        self.neg_edge_index = neg_edge_index


def getSmiles():
    """
    Generates a TSV file with the SMILES representations of the drugs in
    DISNET's graph.
    """
    cnx = mysql.connector.connect(user='edsss_usr', password='1AftIRWJa93P',
                                  host='ares.ctb.upm.es', port='30100',
                                  database='disnet_drugslayer')
    cursor = cnx.cursor()
    query = "SELECT drug_id, drug_name, chemical_structure FROM drug;"
    cursor.execute(query)
    newDf = pd.DataFrame(cursor.fetchall())
    newDf.columns = ['id', 'name', 'stru']
    cursor.close()
    cnx.close()
    newDf.to_csv("data/druStruc.tsv", sep='\t', index=False)


def getGraph(df, complete=False):
    """
    Transforms the SMILES representations into graph representations. In case
    of error, a 0 is stored instead of the graph.

    Input:
        df: Dataframe containing each drug's name and SMILES.
        complete: Whether drugs whose SMILES produce errors should be completed
            by searching for them in PubChem.
    Output:
        Array with all the NetworkX objects representing the drugs.
    """
    networks = []
    # Note: pysmiles discards stereochemical information when parsing.
    for name, smiles in zip(df['name'].tolist(), df['stru'].tolist()):
        if isinstance(smiles, str) and smiles != '0':
            new = read_smiles(smiles).to_directed()
        else:
            if complete:
                p = get_compounds(name, 'name')
                if len(p) > 0:
                    smiles = p[0].canonical_smiles
                    new = read_smiles(smiles)
                    df.loc[df["name"] == name, "stru"] = smiles
                else:
                    df.loc[df["name"] == name, "stru"] = 0
                    new = 0
            else:
                new = 0
                df.loc[df["name"] == name, "stru"] = 0
        networks.append(new)
    df.to_csv("data/druStruc.tsv", sep='\t', index=False)
    return np.array(networks, dtype=object)


def buildDataset(graphs):
    """
    Builds a dataset from the drugs' molecular structures.

    Input:
        graphs: Array containing all the NetworkX objects representing the drugs.
    Output:
        List containing one DrugGraph object per drug.
    """
    res = []
    # Dictionary containing the conversion from chemical element to int.
    elements = {'Error': -1, 'Ag': 0, 'Al': 1, 'As': 2, 'Au': 3, 'B': 4, 'Ba': 5, 'Bi': 6, 'Br': 7, 'C': 8, 'Ca': 9,
                'Cl': 10, 'Co': 11, 'Cr': 12, 'Cu': 13, 'F': 14, 'Fe': 15, 'Ga': 16, 'Gd': 17, 'H': 18, 'He': 19,
                'Hg': 20, 'I': 21, 'In': 22, 'K': 23, 'Kr': 24, 'La': 25, 'Li': 26, 'Lu': 27, 'Mg': 28, 'Mn': 29,
                'N': 30, 'Na': 31, 'O': 32, 'P': 33, 'Pt': 34, 'Ra': 35, 'Rb': 36, 'S': 37, 'Sb': 38, 'Se': 39,
                'Si': 40, 'Sm': 41, 'Sn': 42, 'Sr': 43, 'Tc': 44, 'Ti': 45, 'Tl': 46, 'Xe': 47, 'Yb': 48, 'Zn': 49}
    for graph in graphs:
        if graph != 0:
            # Node features.
            feats2 = []
            feats = [data[1] for data in graph.nodes(data=True)]
            # Needed so that all node features follow the same format.
            for node in feats:
                node2 = {'element': elements[node['element']], 'charge': node['charge'],
                         'aromatic': node['aromatic'], 'hcount': node["hcount"]}
                try:
                    node2['stereo'] = len(node['stereo'])
                except KeyError:
                    node2['stereo'] = 0
                feats2.append(list(node2.values()))
            # Generate all possible edges.
            neg_edges = list(itertools.combinations(list(range(0, len(feats2))), 2))
            # Remove from the negative edge list those that appear in the original graph.
            e0 = []
            e1 = []
            for e in list(graph.edges):
                if e in neg_edges:
                    neg_edges.remove(e)
                e0.append(e[0])
                e1.append(e[1])
            # If there are more negative edges than positive ones, the number of
            # negatives is reduced to match the positives.
            eN0 = []
            eN1 = []
            if len(e1) < len(neg_edges):
                neg_edges = random.sample(neg_edges, len(e0))
            # Format in edge_index format (head, tail).
            for e in neg_edges:
                eN0.append(e[0])
                eN1.append(e[1])
            edges = [e0, e1]
            neg_edges = [eN0, eN1]
            # If there are no negative edges, one needs to be introduced.
            if len(neg_edges[0]) == 0:
                neg_edges = [[0], [0]]
        else:
            # Error case.
            feats2 = [[-1, -1, -1, -1, -1]]
            edges = [[], []]
            neg_edges = [[0], [0]]
        res.append(DrugGraph(graph, torch.tensor(feats2, device=device, dtype=torch.float),
                             torch.tensor(edges, device=device, dtype=torch.int64),
                             torch.tensor(neg_edges, device=device, dtype=torch.int64)))
    return res


def trainAE(model, optimizer, dataset, epochs):
    """
    Trains the autoencoder model.

    Input:
        model: Autoencoder model.
        optimizer: Optimizer to use during training.
        dataset: Dataset on which to train the autoencoder.
        epochs: Number of epochs to train the autoencoder.
    """
    trainer = Trainer(model, optimizer, dataset)
    trainer.fit(epochs)


def getEmbed(model, dataset):
    """
    Generates an embedding for each graph by applying the autoencoder's encoder;
    the embedding (the autoencoder's code) is the mean of the node codes.

    Input:
        model: Autoencoder.
        dataset: Dataset to generate embeddings from.
    Output:
        Drug molecular structure embeddings.
    """
    embeddings = []
    model.eval()
    model = model.to('cpu')
    with torch.no_grad():
        for data in dataset:
            embeddings.append(torch.mean(model.encode(data.x.cpu(), data.edge_index.cpu()), dim=0))
    return embeddings


if __name__ == '__main__':
    getSmiles()
    df = pd.read_csv('data/druStruc.tsv', sep='\t')
    graphs = getGraph(df)
    dataset = buildDataset(graphs)
    dataset2 = []
    # Only the elements without errors are incorporated into the training dataset.
    for elem in dataset:
        if elem.graph != 0:
            dataset2.append(elem)
    # Training.
    model = GAE(GCNEncoder(5, 32))
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trainAE(model, optimizer, dataset2, 500)
    # Getting embeddings.
    embeddings = getEmbed(model, dataset)
    # Saving the model.
    torch.save(model.state_dict(), "./models/structureEmbedder")
    # Drugs without an embedding are assigned a predefined one.
    for i in range(len(embeddings)):
        if dataset[i].graph == 0:
            embeddings[i] = torch.tensor([1] * 32, dtype=torch.float32)
    # Save the embeddings.
    df["embedding"] = embeddings
    df.to_csv("data/features/dru.tsv", sep='\t', index=False)
    torch.save(embeddings, 'data/features/dru.pt')
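A hedged sketch of how downstream code (e.g. heterograph_construction) could reload the artifacts saved above: `torch.load` on the .pt file returns the list of 32-dimensional drug embeddings, and the trained encoder can be restored from its state dict.

```python
import torch
from torch_geometric.nn import GAE
from autoencoder import GCNEncoder

# Reload the per-drug structure embeddings (one 32-dim tensor per drug).
embeddings = torch.load('data/features/dru.pt')
print(len(embeddings), embeddings[0].shape)

# Restore the trained autoencoder for further encoding, if needed.
model = GAE(GCNEncoder(5, 32))
model.load_state_dict(torch.load('./models/structureEmbedder'))
model.eval()
```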