Commit aa6f5904 authored by ADRIAN  AYUSO MUNOZ's avatar ADRIAN AYUSO MUNOZ

Update Files

parent 5d4b10ae
......@@ -19,7 +19,7 @@ from deepsnap.hetero_gnn import (
import warnings
warnings.filterwarnings("ignore")
edges = [('disorder', 'dis_dru_the', 'drug')]
edges = [('phenotype', 'dis_dru_the', 'drug')]
device = 'cuda'
# ---------------------------
......@@ -95,21 +95,21 @@ class HeteroGNN(torch.nn.Module):
for message_type in edges:
nodes_first = None
nodes_second = None
if message_type == ('disorder', 'dis_dru_the', 'drug'):
nodes_first = torch.index_select(x['disorder'], 0, data.edge_label_index[message_type][0, :].long())
if message_type == ('phenotype', 'dis_dru_the', 'drug'):
nodes_first = torch.index_select(x['phenotype'], 0, data.edge_label_index[message_type][0, :].long())
nodes_second = torch.index_select(x['drug'], 0, data.edge_label_index[message_type][1, :].long())
elif message_type == ('drug', 'dru_dis_the', 'disorder'):
elif message_type == ('drug', 'dru_dis_the', 'phenotype'):
nodes_first = torch.index_select(x['drug'], 0, data.edge_label_index[message_type][0, :].long())
nodes_second = torch.index_select(x['disorder'], 0, data.edge_label_index[message_type][1, :].long())
nodes_second = torch.index_select(x['phenotype'], 0, data.edge_label_index[message_type][1, :].long())
elif message_type == ('disorder', 'dse_sym', 'disorder'):
nodes_first = torch.index_select(x['disorder'], 0, data.edge_label_index[message_type][0, :].long())
nodes_second = torch.index_select(x['disorder'], 0, data.edge_label_index[message_type][1, :].long())
elif message_type == ('phenotype', 'dse_sym', 'phenotype'):
nodes_first = torch.index_select(x['phenotype'], 0, data.edge_label_index[message_type][0, :].long())
nodes_second = torch.index_select(x['phenotype'], 0, data.edge_label_index[message_type][1, :].long())
elif message_type == ('disorder', 'sym_dse', 'disorder'):
nodes_first = torch.index_select(x['disorder'], 0, data.edge_label_index[message_type][0, :].long())
nodes_second = torch.index_select(x['disorder'], 0, data.edge_label_index[message_type][1, :].long())
elif message_type == ('phenotype', 'sym_dse', 'phenotype'):
nodes_first = torch.index_select(x['phenotype'], 0, data.edge_label_index[message_type][0, :].long())
nodes_second = torch.index_select(x['phenotype'], 0, data.edge_label_index[message_type][1, :].long())
pred[message_type] = torch.sigmoid(torch.sum(nodes_first * nodes_second, dim=-1))
......@@ -122,8 +122,8 @@ class HeteroGNN(torch.nn.Module):
for message_type in edges:
nodes_first = None
nodes_second = None
if message_type == ('disorder', 'dis_dru_the', 'drug'):
nodes_first = x['disorder']
if message_type == ('phenotype', 'dis_dru_the', 'drug'):
nodes_first = x['phenotype']
nodes_second = x['drug']
for i, elem in enumerate(nodes_first):
......@@ -137,17 +137,17 @@ class HeteroGNN(torch.nn.Module):
for message_type in edges:
nodes_first = None
nodes_second = None
if message_type == ('disorder', 'dis_dru_the', 'drug') and type == 'disease':
nodes_first = x['disorder'][id].unsqueeze(0)
if message_type == ('phenotype', 'dis_dru_the', 'drug') and type == 'phenotype':
nodes_first = x['phenotype'][id].unsqueeze(0)
nodes_second = x['drug']
elif message_type == ('disorder', 'dis_dru_the', 'drug') and type == 'drug':
nodes_first = x['disorder']
elif message_type == ('phenotype', 'dis_dru_the', 'drug') and type == 'drug':
nodes_first = x['phenotype']
nodes_second = x['drug'][id].unsqueeze(0)
elif message_type == ('disorder', 'dse_sym', 'disorder'):
nodes_first = x['disorder']
nodes_second = x['disorder']
elif message_type == ('phenotype', 'dse_sym', 'phenotype'):
nodes_first = x['phenotype']
nodes_second = x['phenotype']
for i, elem in enumerate(nodes_first):
pred[message_type, i] = torch.sigmoid(torch.sum(elem * nodes_second, dim=-1))
......@@ -155,7 +155,7 @@ class HeteroGNN(torch.nn.Module):
def pred(self, data, eid):
x = self.getEmbeddings(data, False)
heads = x['disorder']
heads = x['phenotype']
tails = x['drug']
pred = []
for head, tail in zip(eid[0], eid[1]):
......
......@@ -53,8 +53,8 @@ class DISNETConstructor:
dis, dru, pat, pro, ddi = self.getNodeInfo(full)
nodes = [dis, dru, pat, pro, ddi]
# Store types and its size.
ntypes = ['disorder', 'drug', 'pathway', 'protein', 'drug-drug-interaction']
nsizes = {'disorder': len(dis.index),
ntypes = ['phenotype', 'drug', 'pathway', 'protein', 'drug-drug-interaction']
nsizes = {'phenotype': len(dis.index),
'drug': len(dru.index),
'pathway': len(pat.index),
'protein': len(pro.index),
......@@ -65,8 +65,8 @@ class DISNETConstructor:
dis, dru = self.getNodeInfo(full)
nodes = [dis, dru]
# Store types and its size.
ntypes = ['disorder', 'drug']
nsizes = {'disorder': len(dis.index),
ntypes = ['phenotype', 'drug']
nsizes = {'phenotype': len(dis.index),
'drug': len(dru.index)
}
......@@ -83,8 +83,8 @@ class DISNETConstructor:
# Adding NID to nodes
dis['NID'] = dis.index
dis['node_type'] = 'disorder'
dis['node_id'] = nodes_flat.loc[nodes_flat['node_type'] == 'disorder'].reset_index(drop=True).node_id
dis['node_type'] = 'phenotype'
dis['node_id'] = nodes_flat.loc[nodes_flat['node_type'] == 'phenotype'].reset_index(drop=True).node_id
dru['NID'] = dru.index
dru['node_type'] = 'drug'
......@@ -94,7 +94,7 @@ class DISNETConstructor:
dis_dict = dis[['id', 'node_id']].set_index('id').to_dict()['node_id']
dru_dict = dru[['id', 'node_id']].set_index('id').to_dict()['node_id']
dis_feat = torch.tensor([[1] * 100] * nsizes['disorder'], dtype=torch.float32)
dis_feat = torch.tensor([[1] * 100] * nsizes['phenotype'], dtype=torch.float32)
dru_feat = torch.tensor([[1] * 100] * nsizes['drug'], dtype=torch.float32)
if full:
......@@ -121,11 +121,11 @@ class DISNETConstructor:
ddi_feat = torch.tensor([[1] * 100] * nsizes['drug-drug-interaction'], dtype=torch.float32)
feats = {'disorder': dis_feat, 'drug': dru_feat, 'pathway': pat_feat, 'protein': pro_feat,
feats = {'phenotype': dis_feat, 'drug': dru_feat, 'pathway': pat_feat, 'protein': pro_feat,
'drug-drug-interaction': ddi_feat}
else:
feats = {'disorder': dis_feat, 'drug': dru_feat}
feats = {'phenotype': dis_feat, 'drug': dru_feat}
# Add nodes to the graph
G = nx.DiGraph()
......@@ -240,17 +240,17 @@ class DISNETConstructor:
'dru_ddi': ddi_dru[['druNID', 'ddiNID']].values.tolist()
}
edges_dict = {'dis_dru_the': ('disorder', 'drug'),
'dru_dis_the': ('drug', 'disorder'),
edges_dict = {'dis_dru_the': ('phenotype', 'drug'),
'dru_dis_the': ('drug', 'phenotype'),
'dis_sym': ('disorder', 'disorder'),
'sym_dis': ('disorder', 'disorder'),
'dis_sym': ('phenotype', 'phenotype'),
'sym_dis': ('phenotype', 'phenotype'),
'dis_pat': ('disorder', 'pathway'),
'pat_dis': ('pathway', 'disorder'),
'dis_pat': ('phenotype', 'pathway'),
'pat_dis': ('pathway', 'phenotype'),
'dis_pro': ('disorder', 'protein'),
'pro_dis': ('protein', 'disorder'),
'dis_pro': ('phenotype', 'protein'),
'pro_dis': ('protein', 'phenotype'),
'druA_druB': ('drug', 'drug'),
'druB_druA': ('drug', 'drug'),
......@@ -258,11 +258,11 @@ class DISNETConstructor:
'dru_pro': ('drug', 'protein'),
'pro_dru': ('protein', 'drug'),
'dru_sym_ind': ('drug', 'disease'),
'sym_dru_ind': ('disease', 'drug'),
'dru_sym_ind': ('drug', 'phenotype'),
'sym_dru_ind': ('phenotype', 'drug'),
'dru_sym_sef': ('drug', 'disease'),
'sym_dru_sef': ('disease', 'drug'),
'dru_sym_sef': ('drug', 'phenotype'),
'sym_dru_sef': ('phenotype', 'drug'),
'pro_pat': ('protein', 'pathway'),
'pat_pro': ('pathway', 'protein'),
......@@ -270,11 +270,11 @@ class DISNETConstructor:
'proA_proB': ('protein', 'protein'),
'proB_proA': ('protein', 'protein'),
'ddi_phe': ('drug-drug-interaction', 'disorder'),
'phe_ddi': ('disorder', 'drug-drug-interaction'),
'ddi_phe': ('drug-drug-interaction', 'phenotype'),
'phe_ddi': ('phenotype', 'drug-drug-interaction'),
'ddi_dru': ('drug-drug-interaction', 'dru'),
'dru_ddi': ('dru', 'drug-drug-interaction'),
'ddi_dru': ('drug-drug-interaction', 'drug'),
'dru_ddi': ('drug', 'drug-drug-interaction'),
}
else:
edges = {
......@@ -286,19 +286,20 @@ class DISNETConstructor:
}
edges_dict = {'dis_dru_the': ('disorder', 'drug'),
'dru_dis_the': ('drug', 'disorder'),
edges_dict = {'dis_dru_the': ('phenotype', 'drug'),
'dru_dis_the': ('drug', 'phenotype'),
'dis_sym': ('disorder', 'disorder'),
'sym_dis': ('disorder', 'disorder')
'dis_sym': ('phenotype', 'phenotype'),
'sym_dis': ('phenotype', 'phenotype')
}
for edge_t in edges_dict.keys():
for edge in edges[edge_t]:
try:
G.add_edge(int(edge[0]), int(edge[1]), edge_feature=edge[2], edge_type=edge_t)
except IndexError:
G.add_edge(int(edge[0]), int(edge[1]), edge_type=edge_t)
G.add_edge(int(edge[0]), int(edge[1]), edge_feature=1, edge_type=edge_t)
# --------------------------
# HETEROGRAPGH
......@@ -325,7 +326,7 @@ class DISNETConstructor:
# Nodes data pre-processing (mapping and data for the graph)
nodes_flat = pd.concat([dis, dru],
keys=['disorder', 'drug'],
keys=['phenotype', 'drug'],
names=['node_type', 'NID']).reset_index()
nodes_flat['node_id'] = nodes_flat.index
......@@ -336,8 +337,8 @@ class DISNETConstructor:
# Adding NID to nodes
dis['NID'] = dis.index
dis['node_type'] = 'disorder'
dis['node_id'] = nodes_flat.loc[nodes_flat['node_type'] == 'disorder'].reset_index(drop=True).node_id
dis['node_type'] = 'phenotype'
dis['node_id'] = nodes_flat.loc[nodes_flat['node_type'] == 'phenotype'].reset_index(drop=True).node_id
dru['NID'] = dru.index
dru['node_type'] = 'drug'
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment