Commit 5e5d0f8b authored by ADRIAN  AYUSO MUNOZ's avatar ADRIAN AYUSO MUNOZ

Upload missing files

parent c85f5a42
......@@ -6,8 +6,9 @@
- **documentation**: Instructions to install the needed libraries.
- **graphData**: Data of DISNET's graph.
- **metrics**: Training, testing and RepoDB verification ROC & PRC, also loss evolution for training and validation.
- **logo**: REDIRECTION's logo.
- **models**: Trained models.
- **plots**: Plots of the models' results and metrics.
- **results**: Result files of the RepoDB test and the distribution plots (once these are generated).
- **Code files:**
- deepSnapPred (train & test DeepSnap)
......@@ -15,9 +16,8 @@
- testRepoDB (validate model using RepoDB)
- topN (getting top N new predictions of a model)
- utilities (plotting utilities)
- visualizeDistribution (Checks the distribution of the predictions for a group of diseases, check TFG for more information).
- visualizeEmbeddings (visualization of embeddings)
## Summary
Repository of the AIIM2023 "" paper.
REDIRECTION stands for dRug rEpurposing Disnet lInk pREdiCTION.
\ No newline at end of file
Development was carried out with Python 3.8.10.
To install all the libraries run the following command:
pip install -r libsImport.txt
Some libraries will produce an error, in that case delete the corresponding line and run the command again.
**IMPORTANT**: PyTorch (torch) will be installed as a dependency of other packages, it should be uninstalled and installed again using the command shown below.
**IMPORTANT**: Most libraries are compiled to use CUDA 11.3, if you are using a different version please adapt the installation.
Here are some libraries that usually produce errors and the way to install them:
- **DGL**: pip install dgl-cu113==0.7.2 -f https://data.dgl.ai/wheels/repo.html
- **SNAP**: pip install snap-stanford
- **PyTorch**: pip install torch==1.10.2 torchvision==0.11.3 torchaudio==0.10.2 --extra-index-url https://download.pytorch.org/whl/cu113
- **PyTorchExtra**: pip install torch-scatter==2.0.9 torch-sparse==0.6.12 torch-cluster==1.5.9 torch-spline-conv==1.2.1 torch-geometric==2.0.3 -f https://data.pyg.org/whl/torch-1.10.0+cpu.html
\ No newline at end of file
watch -n0.1 nvidia-smi
\ No newline at end of file
source DISNET/bin/activate
\ No newline at end of file
asttokens==2.0.5
attrs==21.4.0
certifi==2021.10.8
charset-normalizer==2.0.12
click==8.1.2
cycler==0.11.0
deepsnap==0.2.1
dgl-cu113==0.7.2
executing==0.8.2
Flask==2.1.1
fonttools==4.29.1
googledrivedownloader==0.4
idna==3.3
importlib-metadata==4.11.3
iniconfig==1.1.1
install==1.3.5
isodate==0.6.1
itsdangerous==2.1.2
Jinja2==3.0.3
joblib==1.1.0
kiwisolver==1.3.2
littleutils==0.2.2
MarkupSafe==2.1.0
matplotlib==3.5.1
networkx==2.6.3
numpy==1.22.2
nvidia-smi==0.1.3
packaging==21.3
pandas==1.4.1
Pillow==9.0.1
pluggy==1.0.0
py==1.11.0
pyparsing==3.0.7
pytest==7.0.1
python-dateutil==2.8.2
pytz==2021.3
PyYAML==6.0
rdflib==6.1.1
requests==2.27.1
scikit-learn==1.0.2
scipy==1.8.0
seaborn==0.11.2
six==1.16.0
sklearn==0.0
snap-stanford==6.0.0
sorcery==0.2.2
threadpoolctl==3.1.0
tomli==2.0.1
torch==1.10.2+cu113
torch-cluster==1.5.9
torch-geometric==2.0.3
torch-scatter==2.0.9
torch-sparse==0.6.12
torch-spline-conv==1.2.1
torchaudio==0.10.2+cu113
torchvision==0.11.3+cu113
tqdm==4.62.3
typing-extensions==4.1.1
urllib3==1.26.8
Werkzeug==2.1.1
wrapt==1.13.3
yacs==0.1.8
zipp==3.8.0
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
import torch
from deepSnapPred import HeteroGNN, generate_convs_link_pred_layers
from deepsnap.batch import Batch
from torch.utils.data import DataLoader
import argparse
import heterograph_construction
from deepsnap.dataset import GraphDataset
from datetime import datetime
from deepsnap.hetero_gnn import HeteroSAGEConv
def arg_parse():
parser = argparse.ArgumentParser(description='Link pred arguments.')
parser.add_argument('--device', type=str,
help='CPU / GPU device.')
parser.add_argument('--n', type=int,
help='Number of predictions.')
parser.set_defaults(
device='cuda' if torch.cuda.is_available() else 'cpu',
n=6
)
return parser.parse_args()
args = arg_parse()
constructor = heterograph_construction.DISNETConstructor(device=args.device)
toStudy = 'dis_dru_the'
n = args.n #Number of new predictions, top n.
def filterPreds(original, pred, key):
headsO = original.edge_index[key][0, :].long()
new = []
for i, elem in enumerate(pred):
pred_labels = pred[key, i]
head = i
tail = torch.arange(0,len(pred_labels))
indexH = ((headsO == head).nonzero(as_tuple=True)[0])
for index in indexH:
tail = tail[tail != index]
new.append([head, tail, pred_labels[tail].cpu().detach().numpy()])
return new
def getTopNDS(model, original, dataloader, key, n):
print(" Looking for new edges.")
for batch, original in zip(dataloader, original):
batch.to(args.device)
pred = model.predict_all(batch)
new = filterPreds(original, pred, key)
print(" Decoding predictions, this may take a while.")
return constructor.decodePredictions(new, toStudy, n)
def getOriginal():
hetero = constructor.DISNETHeterographDeepSnap(full=True)
dataset = GraphDataset(
[hetero],
task='link_pred',
edge_train_mode='disjoint',
edge_message_ratio=0.8
)
dataset_loader = DataLoader(
dataset, collate_fn=Batch.collate(), batch_size=1
)
return dataset_loader, hetero
def deepSnap():
original, hetero = getOriginal()
conv1, conv2 = generate_convs_link_pred_layers(hetero, HeteroSAGEConv, 32)
model = HeteroGNN(conv1, conv2, hetero, 32).to(args.device)
model.load_state_dict(torch.load("./models/modelDeepSnapPred", map_location=torch.device(args.device)))
model = model.to(args.device)
model.eval()
edge = ('disorder', 'dis_dru_the', 'drug')
toInfer, _ = getOriginal()
print("Started getting top", n, "at", datetime.now().strftime("%H:%M:%S"))
topN = getTopNDS(model, original, toInfer, edge, n)
print("Finished getting top",n, "at", datetime.now().strftime("%H:%M:%S"))
if __name__ == '__main__':
deepSnap()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment