diff --git a/disease_src/disease-disease-simple.py b/disease_src/disease-disease-simple.py
deleted file mode 100644
index 6ed13969caf9b1ed6ec9146cd8be32dc2f7fee6a..0000000000000000000000000000000000000000
--- a/disease_src/disease-disease-simple.py
+++ /dev/null
@@ -1,2937 +0,0 @@
-import dgl
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import tqdm
-from dgl.data.knowledge_graph import FB15k237Dataset
-from dgl.dataloading import GraphDataLoader
-from dgl.nn.pytorch import RelGraphConv
-import pandas as pd
-import math
-from sklearn.metrics import roc_auc_score
-import sklearn.metrics as metrics
-import wandb
-import random
-import dgl.nn.pytorch as dglnn
-from matplotlib import pyplot as plt
-from statistics import mean
-from get_embedding_prot import export_protein_emb
-from pysmiles import read_smiles
-import os
-from models import GS, MLPPredS
-import mysql.connector
-from collections import defaultdict
-import scipy.stats as st
-import dgl.function as fn
-import heterograph_construction
-from deepsnap.hetero_gnn import forward_op
-import torch.multiprocessing as mp
-from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
-
-def read_drug_info():
-    """
-    Read drug information from a MySQL database (cached to a TSV on first run).
-    Returns:
-        pandas.DataFrame: DataFrame containing drug information.
-    """
-    if not os.path.exists("data/druStruc.tsv"):
-        cnx = mysql.connector.connect(user='edsss_usr', password='1AftIRWJa93P',
-                                      host='ares.ctb.upm.es', port='30100',
-                                      database='disnet_drugslayer')
-        cursor = cnx.cursor()
-        query = "SELECT drug_id, drug_name, chemical_structure FROM drug;"
-        cursor.execute(query)
-        newDf = pd.DataFrame(cursor.fetchall(), columns=["id", "name", "struc"])
-        cursor.close()
-        cnx.close()
-        newDf.to_csv("data/druStruc.tsv", sep='\t', index=False)
-    else:
-        newDf = pd.read_csv("data/druStruc.tsv", sep='\t')
-    return newDf
-
-def generate_drug_embedding(device, in_dim, h_dim, mask, epochs=100):
-    """
-    Generate drug embeddings from SMILES-derived molecular graphs using graph neural networks.
-    Args:
-        device (torch.device): Device to use for computations (e.g., 'cuda' or 'cpu').
-        in_dim (int): Input dimension for the model.
-        h_dim (int): Hidden dimension for the model.
-        mask (dict): Train/validation/test split fractions (keys 'train', 'val', 'test').
-        epochs (int): Number of epochs for training the model (default is 100).
-    Returns:
-        pandas.DataFrame: DataFrame containing drug embeddings.
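-
-    Example (illustrative sketch; the split fractions shown here are assumptions,
-    not values taken from this file):
-        >>> mask = {'train': 0.8, 'val': 0.1, 'test': 0.1}
-        >>> emb = generate_drug_embedding(torch.device('cpu'), in_dim=100,
-        ...                               h_dim=100, mask=mask, epochs=10)
-        >>> emb[['name', 'embed']].head()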
- """ - pos_train_g,pos_val_g,pos_test_g=dict(),dict(),dict() - neg_train_g,neg_val_g,neg_test_g=dict(),dict(),dict() - drug_str=read_drug_info() - drug_str[['embed','neg_struc']]=pd.DataFrame([[np.nan, np.nan]], index=drug_str.index) - for i,row in drug_str.iterrows(): - if(row['struc'] != '0' and type(row['struc']) is str): - ss=row["struc"] - sm=read_smiles(row["struc"]).to_directed() - row["struc"]=dgl.from_networkx(sm) - if(row["struc"].num_nodes()==1): - ssmile=ss - restruct=sm - row["struc"].ndata['h']=torch.ones(row['struc'].num_nodes(),in_dim) - row["struc"].ndata['feat']=row["struc"].ndata['h'] - else: - row["struc"]= 0 - - for i,row in drug_str.iterrows(): - if(row["struc"] != 0 and row["struc"].num_edges()>3): - ns=NegativeSamplerSim(row["struc"],1) - #print(row["struc"]) - row["neg_struc"]=ns(row["struc"]) - if(row["neg_struc"] < 3): - row["struc"]=0 - #print(row["struc"]) - pos_train_g[row['name']],pos_val_g[row['name']],pos_test_g[row['name']]= get_subsets_s(row["struc"], mask) - neg_train_g[row['name']],neg_val_g[row['name']],neg_test_g[row['name']]= get_subsets_s(row["neg_struc"], mask) - else: - row["struc"]=0 - model= GS(in_dim,h_dim) - pred=MLPPredS(h_dim) - optimizer= torch.optim.AdamW(list(model.parameters())+list(pred.parameters()), lr=1e-5, betas=(0.9, 0.999), eps=1e-08, ) - - - model=model.to(device) - model.train() - for jj in tqdm.tqdm(range(epochs),total=epochs): - loss=0 - for i,row in drug_str.iterrows(): - if(row['struc']!=0): - loss+=train_d_e(model,pos_train_g[row['name']],neg_train_g[row['name']],pos_val_g[row['name']],neg_val_g[row['name']],pred,pos_train_g[row['name']].ndata['feat']) - optimizer.zero_grad() - loss.backward() - optimizer.step() - model.eval() - for i,row in drug_str.iterrows(): - if(row["struc"] != 0): - row["embed"] = model(row['struc'],embed) - else: - row["embed"] = torch.tensor([1]*h_dim) - drug_str.drop(['struc','neg_struc'], axis=1) - drug_str.to_csv("embeddings/drug_emb.tsv",sep='\t') - return drug_str -def train_d_e(model,pos_train_g,neg_train_g,pos_val_g,neg_val_g,pred,embed): - """ - Training function for drug embedding model. - Args: - model: Graph neural network model. - pos_train_g: Positive training graph. - neg_train_g: Negative training graph. - pos_val_g: Positive validation graph. - neg_val_g: Negative validation graph. - pred: Predictor model. - embed: Embedding vector. - Returns: - torch.Tensor: Loss value. - """ - h=model(pos_train_g,embed) - # each row in the triplets is a 3-tuple of (source, relation, destination) - score=pred(pos_train_g,h) - neg_score=pred(neg_train_g,h) - labelsss=np.zeros(len(score)+len(neg_score), dtype=np.float32) - labelsss[:len(score)]=1 - val_score=pred(pos_train_g,h) - val_neg_score=pred(neg_train_g,h) - predict_loss = F.binary_cross_entropy_with_logits(torch.cat((score,neg_score),dim=0), torch.from_numpy(labelsss)) - labelsss=np.zeros(len(val_score)+len(val_neg_score), dtype=np.float32) - val_loss = F.binary_cross_entropy_with_logits(torch.cat((val_score,val_neg_score),dim=0), torch.from_numpy(labelsss)) - #print(predict_loss) - return torch.stack([predict_loss,val_loss]).mean(dim=0) -def calc_d_e_score(): - """ - Calculate the score for drug embedding. - - Args: - triplets (torch.Tensor): Tensor containing triplets (edges). - g (dgl.DGLHeteroGraph): Graph data. - embedding (torch.Tensor): Embedding vector. - etypeu (str): Edge type. - score (torch.Tensor): Previous score tensor. - - Returns: - torch.Tensor: Score tensor. - torch.Tensor: Labels tensor. 
- """ - edges=(triplets[:, 0],triplets[:, 1]) - score_graph=dgl.heterograph(edges,num_nodes_dict={ntype: g.num_nodes(ntype) for ntype in g.ntypes}) - score_graph.ndata['h']=embedding - if etypeu is not None: - for etype in g.etypes: - if score is not None: - score=torch.cat((score,self.Pred(score_graph,embedding,etype)),dim=0) - labelss=torch.cat((labelss,labels[etype]),dim=0) - else: - score=self.Pred(score_graph,embedding,etype) - labelss=labels[etype] - else: - score=self.Pred(score_graph,embedding,etype) - labelss=labels[etype] - return score,labelss - -def loadNodes(full): - """ - Load nodes from files. - - Args: - full (bool): Flag indicating whether to load full data. - - Returns: - pd.DataFrame: Dataframe containing disease nodes. - pd.DataFrame: Dataframe containing drug nodes. - (Optional) pd.DataFrame: Dataframe containing patient nodes. - (Optional) pd.DataFrame: Dataframe containing protein nodes. - (Optional) pd.DataFrame: Dataframe containing DDI nodes. - """ - dis = pd.read_csv('data/nodes/dis.tsv', sep='\t') - dru = pd.read_csv('data/nodes/dru.tsv', sep='\t') - - if full: - pat = pd.read_csv('data/nodes/pat.tsv', sep='\t') - pro = pd.read_csv('data/nodes/pro.tsv', sep='\t') - ddi = pd.read_csv('data/nodes/ddi.tsv', sep='\t') - - return dis, dru, pat, pro, ddi - - else: - return dis, dru - -# for building training/testing graphs -def load_data(full): - """ - Load data for building training/testing graphs. - - Args: - full (bool): Flag indicating whether to load full data. - - Returns: - pd.DataFrame: Dataframe containing disease-drug-theatre links. - pd.DataFrame: Dataframe containing disease-symptom links. - (Optional) pd.DataFrame: Dataframe containing disease-patient links. - (Optional) pd.DataFrame: Dataframe containing disease-protein links. - (Optional) pd.DataFrame: Dataframe containing drug-drug links. - (Optional) pd.DataFrame: Dataframe containing drug-protein links. - (Optional) pd.DataFrame: Dataframe containing drug-symptom (indication) links. - (Optional) pd.DataFrame: Dataframe containing drug-symptom (side effect) links. - (Optional) pd.DataFrame: Dataframe containing protein-patient links. - (Optional) pd.DataFrame: Dataframe containing protein-protein links. - (Optional) pd.DataFrame: Dataframe containing DDI-phenotype links. - (Optional) pd.DataFrame: Dataframe containing DDI-drug links. - """ - dis_dru_the = pd.read_csv('data/links/dis_dru_the.tsv', sep='\t') - dis_sym = pd.read_csv('data/links/dis_sym.tsv', sep='\t') - - if full: - dis_pat = pd.read_csv('data/links/dis_pat.tsv', sep='\t') - dis_pro = pd.read_csv('data/links/dis_pro.tsv', sep='\t') - dru_dru = pd.read_csv('data/links/dru_dru.tsv', sep='\t') - dru_pro = pd.read_csv('data/links/dru_pro.tsv', sep='\t') - dru_sym_ind = pd.read_csv('data/links/dru_sym_ind.tsv', sep='\t') - dru_sym_sef = pd.read_csv('data/links/dru_sym_sef.tsv', sep='\t') - pro_pat = pd.read_csv('data/links/pro_pat.tsv', sep='\t') - pro_pro = pd.read_csv('data/links/pro_pro.tsv', sep='\t') - ddi_phe = pd.read_csv('data/links/ddi_phe.tsv', sep='\t') - ddi_dru = pd.read_csv('data/links/ddi_dru.tsv', sep='\t') - - return dis_dru_the, dis_sym, dis_pat, dis_pro, dru_dru, dru_pro, dru_sym_ind, dru_sym_sef, pro_pat, \ - pro_pro, ddi_phe, ddi_dru - - - else: - return dis_dru_the, dis_sym - -def create_heterograph(full): - """ - Create a heterograph from the provided data. - - Args: - full (bool): Flag indicating whether to load full data. - - Returns: - dgl.DGLHeteroGraph: Heterogeneous graph. 
- list: List of dictionaries containing node mappings. - pd.DataFrame: Repository database. - """ - dict_edges={} - num_nodes={} - edges_weights={} - edges_w={} - if full: - dis_dru_the, dis_sym, dis_pat, dis_pro, dru_dru, dru_pro, dru_sym_ind, dru_sym_sef, pro_pat, \ - pro_pro, ddi_phe, ddi_dru = load_data(full) - dis, dru, pat, pro, ddi = loadNodes(full) - num_nodes={} - - dis=dis.rename(lambda x: "node_id" if x != "name" else x, axis=1) - dru=dru.rename(lambda x: "node_id" if x != "name" else x, axis=1) - pat=pat.rename(lambda x: "node_id" if x != "name" else x, axis=1) - pro=pro.rename(lambda x: "node_id" if x != "name" else x, axis=1) - ddi=ddi.rename(lambda x: "node_id" if x != "name" else x, axis=1) - - dis=dis.drop_duplicates(keep='last').reset_index(drop=True).rename_axis("id").reset_index() - dru=dru.drop_duplicates(keep='last').reset_index(drop=True).rename_axis("id").reset_index() - pat=pat.drop_duplicates(keep='last').reset_index(drop=True).rename_axis("id").reset_index() - pro=pro.drop_duplicates(keep='last').reset_index(drop=True).rename_axis("id").reset_index() - ddi=ddi.drop_duplicates(keep='last').reset_index(drop=True).rename_axis("id").reset_index() - - - num_nodes["protein"]= pro.shape[0] - num_nodes["drug_drug_interaction"]= ddi.shape[0] - num_nodes["pathway"]=pat.shape[0] - - dis_dict= dis[["id","node_id"]].set_index("node_id").to_dict()["id"] - dru_dict= dru[["id","node_id"]].set_index("node_id").to_dict()["id"] - pat_dict= pat[["id","node_id"]].set_index("node_id").to_dict()["id"] - pro_dict= pro[["id","node_id"]].set_index("node_id").to_dict()["id"] - ddi_dict= ddi[["id","node_id"]].set_index("node_id").to_dict()["id"] - - - - #print(dis[["id","node_id"]].set_index("node_id").to_dict()) - - dis_pat['disID'] = dis_pat.dis.map(dis_dict) - dis_pat['patID'] = dis_pat.pat.map(pat_dict) - - #print(dis_pro) - - dis_pro['disID'] = dis_pro.dis.map(dis_dict) - dis_pro['proID'] = dis_pro.pro.map(pro_dict) - - dru_dru['druAID'] = dru_dru.drA.map(dru_dict) - dru_dru['druBID'] = dru_dru.drB.map(dru_dict) - - dru_pro['druID'] = dru_pro.dru.map(dru_dict) - dru_pro['proID'] = dru_pro.pro.map(pro_dict) - - dru_sym_ind['druID'] = dru_sym_ind.dru.map(dru_dict) - dru_sym_ind['symID'] = dru_sym_ind.sym.map(dis_dict) - - dru_sym_sef['druID'] = dru_sym_sef.dru.map(dru_dict) - dru_sym_sef['symID'] = dru_sym_sef.sym.map(dis_dict) - - pro_pat['proID'] = pro_pat.pro.map(pro_dict) - - pro_pat['patID'] = pro_pat.pat.map(pat_dict) - - pro_pro['proAID'] = pro_pro.prA.map(pro_dict) - pro_pro['proBID'] = pro_pro.prB.map(pro_dict) - - ddi_phe['ddiID'] = ddi_phe.ddi.map(ddi_dict) - ddi_phe['disID'] = ddi_phe.phe.map(dis_dict) - - ddi_dru['ddiID'] = ddi_dru.ddi.map(ddi_dict) - ddi_dru['druID'] = ddi_dru.dru.map(dru_dict) - - - - dict_edges[("drug_drug_interaction", "ddi_dru", "drug")]=(torch.tensor(ddi_dru["ddiID"].to_list()),torch.tensor(ddi_dru["druID"].to_list())) - dict_edges[("drug_drug_interaction", "ddi_phe", "disease")]=(torch.tensor(ddi_phe['ddiID'].to_list()),torch.tensor(ddi_phe['disID'].to_list())) - dict_edges[("protein", "pro_pro", "protein")]=(torch.tensor(pro_pro['proAID'].to_list()),torch.tensor(pro_pro["proBID"].to_list())) - dict_edges[("protein", "pro_pat", "pathway")]=(torch.tensor(pro_pat['proID'].to_list()),torch.tensor(pro_pat["patID"].to_list())) - dict_edges[("drug", "dru_sym_sef", "disease")]=(torch.tensor(dru_sym_sef['druID'].to_list()),torch.tensor(dru_sym_sef["symID"].to_list())) - dict_edges[("drug", "dru_sym_ind", 
"disease")]=(torch.tensor(dru_sym_ind['druID'].to_list()),torch.tensor(dru_sym_ind["symID"].to_list())) - dict_edges[("drug", "dru_pro", "protein")]=(torch.tensor(dru_pro["druID"].to_list()),torch.tensor(dru_pro["proID"].to_list())) - dict_edges[("drug", "dru_dru", "drug")]=(torch.tensor(dru_dru['druAID'].to_list()),torch.tensor(dru_dru['druBID'].to_list())) - dict_edges[("disease", "dis_pro", "protein")]=(torch.tensor(dis_pro["disID"].to_list()),torch.tensor(dis_pro['proID'].to_list())) - dict_edges[("disease", "dis_pat", "pathway")]=(torch.tensor(dis_pat["disID"].to_list()),torch.tensor(dis_pat["patID"].to_list())) - edges_weights["dis_pro"]=dis_pro["w"].to_list() - edges_weights["dru_sym_sef"]=dru_sym_sef["w"].to_list() - - else: - dis_dru_the, dis_sym= load_data(full) - dis, dru = loadNodes(full) - - dis=dis.rename(lambda x: "node_id" if x != "name" else x, axis=1) - dru=dru.rename(lambda x: "node_id" if x != "name" else x, axis=1) - - dis=dis.drop_duplicates(keep='last').reset_index(drop=True).rename_axis("id").reset_index() - dru=dru.drop_duplicates(keep='last').reset_index(drop=True).rename_axis("id").reset_index() - - dis_dict= dis[["id","node_id"]].set_index("node_id").to_dict()["id"] - dru_dict= dru[["id","node_id"]].set_index("node_id").to_dict()["id"] - - num_nodes["disease"]= dis.shape[0] - num_nodes["drug"]= dru.shape[0] - - dis_sym['disID'] = dis_sym.dis.map(dis_dict) - dis_sym['symID'] = dis_sym.sym.map(dis_dict) - repo_db=pd.read_csv("testData/drugdis_repodb_ALLlinks.tsv",sep='\t') - repo_db['disID'] = repo_db.dis.map(dis_dict) - repo_db['druID'] = repo_db.dru.map(dru_dict) - merged_df=pd.merge(dis_dru_the, repo_db, left_on=['dis', 'dru'], right_on=['dis', 'dru'], how='left', indicator=True) - dis_dru_the_C = dis_dru_the.copy() - dis_dru_the = merged_df[merged_df['_merge'] == 'left_only'] - -# Drop the indicator column and reset the index if needed - dis_dru_the = dis_dru_the.drop(columns=['_merge']).reset_index(drop=True) - dis_dru_the["disID"]= dis_dru_the.dis.map(dis_dict) - dis_dru_the["druID"]= dis_dru_the.dru.map(dru_dict) - dict_edges[("disease", "dis_dru_the", "drug")]=(torch.tensor(dis_dru_the["disID"].to_list()),torch.tensor(dis_dru_the["druID"].to_list())) - dict_edges[("disease", "dis_sym", "disease")]=(torch.tensor(dis_sym["disID"].to_list()),torch.tensor(dis_sym["symID"].to_list())) - g=dgl.heterograph(dict_edges,num_nodes_dict=num_nodes) - #print((torch.tensor(dis_dru_the["disID"].to_list()+repo_db['disID'].to_list()),torch.tensor(dis_dru_the["druID"].to_list()+repo_db['druID'].to_list()))) - dict_edges[("disease", "dis_dru_the", "drug")]=(torch.tensor(dis_dru_the["disID"].to_list()+repo_db['disID'].to_list()),torch.tensor(dis_dru_the["druID"].to_list()+repo_db['druID'].to_list())) - comp_g=dgl.heterograph(dict_edges,num_nodes_dict=num_nodes) - g.ndata['h']={ntype: torch.tensor([[1]*400 for u in range(0,g.num_nodes(ntype))]) for ntype in g.ntypes} - comp_g.ndata['h']={ntype: torch.tensor([[1]*400 for u in range(0,comp_g.num_nodes(ntype))]) for ntype in comp_g.ntypes} - dict_w=dict() - for etype in comp_g.etypes: - - edges_weights[etype]=torch.tensor([1]*comp_g.num_edges(etype)) if etype not in edges_weights else torch.tensor(edges_weights[etype]) - edges_w[etype]=edges_weights[etype] - edges_w["rev_"+etype]=edges_weights[etype] - src,dst=comp_g.edges(etype=etype) - dataf=pd.DataFrame({'src':src,'dst':dst,'w':edges_w[etype]}) - dataf['src'],dataf['dst']=dataf['src'].astype(str),dataf['dst'].astype(str) - dict_w[etype]=dataf.set_index(['src', 
'dst'])['w'].to_dict()
-        dict_w["rev_" + etype] = dataf.set_index(['dst', 'src'])['w'].to_dict()
-    comp_g.edata['w'] = edges_weights
-    if full:
-        return g, [dru_dict, pro_dict], repo_db, comp_g, dict_w
-    else:
-        return g, [dru_dict], repo_db, comp_g, dict_w
-def create_distances_subgraph(graph, r_p=5, path_pn=20, ntype="disease"):
-    """
-    Build a node-to-node distance subgraph by random walks over the heterograph.
-
-    Args:
-        graph (dgl.DGLHeteroGraph): Input heterogeneous graph.
-        r_p (int): Number of random walks started per node.
-        path_pn (int): Maximum length of each random walk.
-        ntype (str): Node type the walks start from and return to.
-
-    Returns:
-        list: Per-node path metrics.
-        dgl.DGLGraph: Subgraph connecting nodes of `ntype` reached by the walks.
-    """
-    Metrics, true_paths = random_dst(graph, ntype, r_p=r_p, path_n=path_pn)
-    subgraph = dgl.graph(true_paths, num_nodes=graph.num_nodes(ntype))
-    return Metrics, subgraph
-def random_dst(graph, ntype, r_p=5, path_n=20):
-    metrics = []
-    paths = []
-    gra = dgl.heterograph({etype: graph.edges(etype=etype) for etype in graph.canonical_etypes},
-                          num_nodes_dict={stype: graph.num_nodes(stype) for stype in graph.ntypes})
-
-    repeat = True
-    while repeat:
-        for n in gra.nodes(ntype):
-            dists = [[0] for _ in range(gra.num_nodes(ntype))]
-            u, d = random_walk_path(gra, n, ntype, r_p, path_n)
-            for i in range(0, len(u)):
-                dists[u[i]].append(d[i])
-            if len(u) > 0:
-                repeat = False
-            metrics.append(dists)
-            paths.append([i for i, v in enumerate(dists) if v != [0]])
-    return metrics, torch.tensor(paths)
-def random_walk_path(g, n, ntype, r_p, path_n):
-    final_path = dict()
-    xNode = n
-    for i in range(r_p):
-        yNode = dgl.sampling.random_walk(g, [xNode], length=path_n)
-        exist_p, path, node = cut_path(ntype, g, yNode)
-        if exist_p:
-            if node in final_path:
-                final_path[node] = final_path[node].expand(max(final_path[node].shape[0], path.shape[0])) * 1 / final_path[node].shape[0] + path.expand(max(final_path[node].shape[0], path.shape[0])) * 1 / path.shape[0]
-            else:
-                path_emb = [0 for _ in range(len(g.etypes))]
-                for u in range(0, len(path)):
-                    path_emb[path[u]] += u * len(g.etypes) + path[u]
-                final_path[node] = torch.tensor(path_emb)
-    return list(final_path.keys()), list(final_path.values())
-def cut_path(ntype, g, NodesPath):
-    path = NodesPath[0].squeeze()
-    etyp = NodesPath[1].squeeze()
-    broken = False
-    break_id = None
-    for i in range(0, len(etyp)):
-        if g.to_canonical_etype(g.etypes[etyp[i]])[2] == ntype:
-            break_id = i
-            broken = True
-            break
-    emb = None
-    if broken:
-        emb = path[:break_id]
-        return broken, path[:break_id], emb
-    return broken, break_id, emb
-def get_incidence(g, ntype, applySubcosts=None, k=0.001):
-    metrics = []
-    paths = []
-
-    # Get initial etypes and adjacency matrices
-    etypes, etype_mat = get_etype_subtypes(g, ntype, applySubcosts)
-    last_etypes = set(et[2] for et in etypes if et[2] != ntype)
-
-    # Initialize paths with the first set of etypes
-    for etype in etypes:
-        if etype[2] != ntype:
-            paths.append([[etype], {etype: [etype_mat[(etype[1], etype[2])]]}])
-        else:
-            path_emb = [0] * len(g.etypes)
-            path_emb[g.etypes.index(etype[1])] = g.etypes.index(etype[1])
-            metrics.append({"path": [etype], "mat": etype_mat[(etype[1], etype[2])], "path_emb": path_emb})
-
-    def parallel_matmul(etype_pair):
-        prev_mat, adj_mat = etype_pair
-        return prev_mat @ adj_mat
-
-    while last_etypes:
-        edg, adj_mat = {}, {}
-        new_last_etypes = set()
-        for ety in last_etypes:
-            e, ad = get_etype_subtypes(g, ety, applySubcosts)
-            edg[ety] = e
-            adj_mat[ety] = ad
-        new_paths = []
-        # A thread pool is used here because the locally defined
-        # `parallel_matmul` cannot be pickled for a process pool.
-        with ThreadPoolExecutor(max_workers=mp.cpu_count()) as executor:
-            future_to_path = {}
-            for path in paths:
-                for etype in edg[path[0][-1][2]]:
-                    prev_mat = path[1][path[0][-1]][-1]
-                    adj_matrix = adj_mat[path[0][-1][2]][(etype[1], etype[2])]
-                    # Parallelize the matrix multiplication
-                    future = executor.submit(parallel_matmul, (prev_mat, adj_matrix))
-                    future_to_path[future] = (path, etype)
-            for future in future_to_path:
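-                # Collect each finished multiplication: paths that have returned
-                # to `ntype` are recorded as metrics, the others are queued for
-                # further extension.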
-                path, etype = future_to_path[future]
-                new_mat = future.result()
-                if etype[2] == ntype:
-                    path_emb = [0 for _ in range(len(g.etypes))]
-                    repath = list(path[0]) + [etype]
-                    dict_etypes = {v: c for (c, v) in enumerate(g.etypes)}
-                    for u in range(len(repath)):
-                        idx = dict_etypes[repath[u][1]]
-                        path_emb[idx] += u * len(g.etypes) + idx
-                    metrics.append({"path": repath, "mat": new_mat, "path_emb": path_emb})
-                else:
-                    new_mat_simp = new_mat.clone()
-                    new_mat_simp[new_mat > 1] = 1
-                    not_add = False
-                    for a_mat in path[1].get(etype, []):
-                        x_mat = a_mat.clone()
-                        x_mat[a_mat > 1] = 1
-                        if torch.equal(new_mat_simp, x_mat):
-                            not_add = True
-                            break
-                    if not not_add:
-                        new_last_etypes.add(etype[2])
-                        new_path = [list(path[0]) + [etype], dict(path[1])]
-                        if etype in new_path[1]:
-                            new_path[1][etype].append(new_mat)
-                        else:
-                            new_path[1][etype] = [new_mat]
-                        new_paths.append(new_path)
-        last_etypes = new_last_etypes
-        paths = new_paths
-    matrix_overload = None
-    metrics_overloard = None
-    if len(metrics) < 1:
-        return None, None
-    for u in metrics:
-        matrix_overload = u['mat'].todense() / len(u['path']) if matrix_overload is None else matrix_overload + u['mat'].todense() / len(u['path'])
-        metrics_overloard = (u['mat'].todense() * u['path_emb']).sum(axis=1) if metrics_overloard is None else metrics_overloard + (u['mat'].todense() * u['path_emb']).sum(axis=1)
-    matrix_overload = matrix_overload / len(metrics)
-    if matrix_overload.shape[0] == matrix_overload.shape[1]:
-        # Normalise each row of a square matrix by its diagonal entry.
-        for y in range(matrix_overload.shape[0]):
-            matrix_overload[y, :] = matrix_overload[y, :] / matrix_overload[y, y]
-
-    matrix_overload = torch.from_numpy(np.asarray(matrix_overload))
-    matrix_overload = torch.sigmoid(matrix_overload / k)
-    return matrix_overload, metrics_overloard
-def get_etype_subtypes(g, ntype, applySubcosts):
-    adj_mat = {}
-    xtype = []
-    for srctype, etype, dsttype in g.canonical_etypes:
-        if srctype == ntype:
-            xtype.append((srctype, etype, dsttype))
-            if applySubcosts is not None:
-                adj_mat[(etype, dsttype)] = g.adjacency_matrix(etype=etype) * applySubcosts[etype]
-            else:
-                adj_mat[(etype, dsttype)] = g.adjacency_matrix(etype=etype)
-    return xtype, adj_mat
-def get_subset_g(g, mask, num_rels, bidirected=False):
-    """
-    Get a subset of the graph based on the provided mask.
-
-    Args:
-        g (dgl.DGLHeteroGraph): Original graph.
-        mask (torch.Tensor): Mask indicating the subset of edges to be selected.
-        num_rels (int): Number of relation types.
-        bidirected (bool): Flag indicating whether the graph is bidirected.
-
-    Returns:
-        dgl.DGLHeteroGraph: Subset of the original graph.
-    """
-    src, dst = g.edges()
-    sub_src = src[mask]
-    sub_dst = dst[mask]
-    sub_rel = g.edata["etype"][mask]
-
-    if bidirected:
-        sub_src, sub_dst = torch.cat([sub_src, sub_dst]), torch.cat([sub_dst, sub_src])
-        sub_rel = torch.cat([sub_rel, sub_rel + num_rels])
-
-    sub_g = dgl.graph((sub_src, sub_dst), num_nodes=g.num_nodes())
-    sub_g.edata[dgl.ETYPE] = sub_rel
-    return sub_g
-def disjoint_split_hetero(graph, train_ratio=0.5, val_ratio=0.4, test_ratio=0.1):
-    """
-    Create a disjoint split of the heterogeneous graph into training, validation, and test sets.
-
-    Args:
-        graph (dgl.DGLHeteroGraph): The input graph.
-        train_ratio (float): Ratio of edges for the training set.
-        val_ratio (float): Ratio of edges for the validation set.
-        test_ratio (float): Ratio of edges for the test set.
-
-    Returns:
-        dict: Disjoint 'train', 'val', and 'test' graphs holding the split edges of each edge type.
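-
-    Example (illustrative sketch using the default ratios):
-        >>> splits = disjoint_split_hetero(graph)
-        >>> train_g, test_g = splits['train'], splits['test']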
- """ - num_nodes_dict={ntype: g.num_nodes(ntype) for ntype in g.ntypes} - # Shuffle the edges of each edge type - train_edges, val_edges, test_edges = {}, {}, {} - - for etype in g.canonical_etypes: - src, dst = g.edges(etype=etype) - num_edges = g.num_edges(etype=etype) - - edges_df = pd.DataFrame({'src': src.numpy(), 'dst': dst.numpy()}) - train_edges_df = edges_df.sample(frac=mask['train']) - val_test_edges_df = edges_df.drop(train_edges_df.index) - val_edges_df = val_test_edges_df.sample(frac=mask['val'] / (1 - mask['train'])) - test_edges_df = val_test_edges_df.drop(val_edges_df.index) - - train_edges[etype] = (torch.tensor(train_edges_df['src'].tolist()), torch.tensor(train_edges_df['dst'].tolist())) - val_edges[etype] = (torch.tensor(val_edges_df['src'].tolist()), torch.tensor(val_edges_df['dst'].tolist())) - test_edges[etype] = (torch.tensor(test_edges_df['src'].tolist()), torch.tensor(test_edges_df['dst'].tolist())) - - train_g = dgl.heterograph(train_edges,num_nodes_dict=num_nodes_dict) - val_g = dgl.heterograph(val_edges,num_nodes_dict=num_nodes_dict) - test_g = dgl.heterograph(test_edges,num_nodes_dict=num_nodes_dict) - - return {'train': train_g, 'val': val_g, 'test': test_g} -def get_subsets(g,mask, bidirected=False): - """ - Generate training, validation, and test subsets of a graph. - - Args: - g (dgl.DGLHeteroGraph): Original graph. - mask (dict): Dictionary containing train, val, and test fractions. - bidirected (bool): Flag indicating whether the graph is bidirected. - - Returns: - dgl.DGLHeteroGraph: Training subset. - dgl.DGLHeteroGraph: Validation subset. - dgl.DGLHeteroGraph: Test subset. - """ - train_edges=dict() - val_edges=dict() - test_edges=dict() - for ss,etype,dd in g.canonical_etypes: - num_edges = g.num_edges(etype) - src, dst = g.edges(etype=etype) - edges = pd.DataFrame() - edges['src'] = src.squeeze() - edges['dst'] = dst.squeeze() - - train_edg = edges.sample(frac=mask['train']) - #print(train_edg) - train_edges[(ss, etype, dd)] = (torch.tensor(train_edg['src'].to_list()), torch.tensor(train_edg['dst'].to_list())) - #print(edges[~edges.isin(train_edg).all(axis=1)]) - val_test_edg = pd.merge(edges,train_edg, indicator=True, how='outer').query('_merge=="left_only"').drop('_merge', axis=1) - #val_test_edg = edges[~edges.isin(train_edg).all(axis=1)] - val_edg = val_test_edg.sample(frac=mask['val'] / (1 - mask['train'])) - test_edg = pd.merge(val_test_edg,val_edg, indicator=True, how='outer').query('_merge=="left_only"').drop('_merge', axis=1) - #test_edg = val_test_edg[~val_test_edg.isin(val_edg).all(axis=1)] - val_edges[(ss, etype, dd)] = (torch.tensor(val_edg['src'].to_list()), torch.tensor(val_edg['dst'].to_list())) - test_edges[(ss, etype, dd)] = (torch.tensor(test_edg['src'].to_list()), torch.tensor(test_edg['dst'].to_list())) - - - - if bidirected: - sub_src, sub_dst = torch.cat([sub_src, sub_dst]), torch.cat( - [sub_dst, sub_src] - ) - sub_rel = torch.cat([sub_rel, sub_rel + num_rels]) - nodes_dict={ntype: g.num_nodes(ntype) for ntype in g.ntypes} - #print(train_edges) - train_g = dgl.heterograph(train_edges, num_nodes_dict=nodes_dict) - val_g = dgl.heterograph(val_edges, num_nodes_dict=nodes_dict) - test_g = dgl.heterograph(test_edges, num_nodes_dict=nodes_dict) - train_g.ndata['h']=g.ndata['h'] - val_g.ndata['h']=g.ndata['h'] - test_g.ndata['h']=g.ndata['h'] - return train_g, val_g, test_g -def get_subsets_s(g,mask, bidirected=False): - """ - Generate training, validation, and test subsets of a single-graph. 
- - Args: - g (dgl.DGLGraph): Original graph. - mask (dict): Dictionary containing train, val, and test fractions. - bidirected (bool): Flag indicating whether the graph is bidirected. - - Returns: - dgl.DGLGraph: Training subset. - dgl.DGLGraph: Validation subset. - dgl.DGLGraph: Test subset. - """ - num_edges = g.num_edges() - src, dst = g.edges() - edges = pd.DataFrame() - edges['src'] = src.squeeze() - edges['dst'] = dst.squeeze() - #print(len(src)*mask['train']) - train_edg = edges.sample(frac=mask['train']) if len(src)*mask['train'] <= len(src)-2 else edges.sample(len(src)-2) - #print(train_edg) - train_edges = (torch.tensor(train_edg['src'].to_list()), torch.tensor(train_edg['dst'].to_list())) - #print(edges[~edges.isin(train_edg).all(axis=1)]) - merged = pd.merge(edges,train_edg, how='outer', indicator=True) - val_test_edg = merged[merged['_merge'] == 'left_only'].drop('_merge', axis=1) - #val_test_edg = edges[~edges.isin(train_edg).all(axis=1)] - val_edg = val_test_edg.sample(frac=mask['val'] / (1 - mask['train'])) if len(val_test_edg)*(mask['val'] / (1 - mask['train'])) > 1 and len(val_test_edg)*(mask['val'] / (1 - mask['train'])) < len(val_test_edg)-1 else val_test_edg.sample(1) if len(val_test_edg)*(mask['val'] / (1 - mask['train'])) > 1 else val_test_edg.sample(len(val_test_edg)-1) - merged = pd.merge(val_test_edg,val_edg, how='outer', indicator=True) - test_edg = merged[merged['_merge'] == 'left_only'].drop('_merge', axis=1) - """ - tmp = val_edg.sample(1) if test_edg.empty else test_edg - if(test_edg.empty): - merged = pd.merge(val_edg,tmp, how='outer', indicator=True) - val_edg = merged[merged['_merge'] == 'left_only'].drop('_merge', axis=1) - test_edg = tmp if test_edg.empty else test_edg - """ - #test_edg = val_test_edg[~val_test_edg.isin(val_edg).all(axis=1)] - val_edges = (torch.tensor(val_edg['src'].to_list()), torch.tensor(val_edg['dst'].to_list())) - test_edges = (torch.tensor(test_edg['src'].to_list()), torch.tensor(test_edg['dst'].to_list())) - #print(val_test_edg) - #print("ttest") - #print(test_edges) - - - if bidirected: - sub_src, sub_dst = torch.cat([sub_src, sub_dst]), torch.cat( - [sub_dst, sub_src] - ) - sub_rel = torch.cat([sub_rel, sub_rel + num_rels]) - #nodes_dict={ntype: g.num_nodes(ntype) for ntype in g.ntypes} - #print(train_edges) - train_g = dgl.graph(train_edges, num_nodes=g.num_nodes()) - val_g = dgl.graph(val_edges, num_nodes=g.num_nodes()) - test_g = dgl.graph(test_edges, num_nodes=g.num_nodes()) - train_g.ndata['h']=g.ndata['h'] - train_g.ndata['feat']=g.ndata['feat'] - val_g.ndata['h']=g.ndata['h'] - val_g.ndata['feat']=g.ndata['feat'] - test_g.ndata['h']=g.ndata['h'] - test_g.ndata['feat']=g.ndata['feat'] - return train_g, val_g, test_g -class GlobalUniform: - """ - Global uniform sampler for graph edges. - - Args: - g (dgl.DGLGraph): Original graph. - sample_size (int): Size of the sample. - - Returns: - dict: Sampled edges. - """ - def __init__(self, g, sample_size): - self.sample_size = sample_size - self.eids = {etype: np.arange(g.num_edges(etype)) for src,etype,_ in g.canonical_etypes} - self.g=g - def sample(self): - return {etype: torch.from_numpy(np.random.choice(self.eids[etype], self.sample_size)) for src,etype,_ in self.g.canonical_etypes} - -class NegativeSamplerHet(object): - """ - Negative sampler for heterogeneous graphs. - - Args: - g (dgl.DGLHeteroGraph): Original graph. - k (int): Number of negative samples per positive sample. - neg_share (bool): Flag indicating whether to share negative samples between relations. 
- - Returns: - dict: Negative samples for each edge type. - """ - def __init__(self, g, k, neg_share=False): - self.src_weights ={etype: torch.tensor(np.arange(0,g.number_of_nodes(src),dtype=float)) for src,etype,_ in g.canonical_etypes} - self.dst_weights={etype: torch.tensor(np.arange(0,g.number_of_nodes(dst),dtype=float)) for _,etype,dst in g.canonical_etypes} - self.k = k - self.neg_share = neg_share - - def __call__(self, g, eids_dict,mode="src_dst_corr"): - """ - Generate negative samples for given positive samples. - - Args: - g (dgl.DGLHeteroGraph): Original graph. - eids_dict (dict): Dictionary of positive samples for each edge type. - mode (str): Sampling mode, one of "src_dst_corr", "dst_corr", or "random". - - Returns: - dict: Negative samples for each edge type. - """ - result_dict=dict() - for etype, eids in eids_dict.items(): - src, dst = g.find_edges(eids, etype=etype) - maximum= g.nodes(g.to_canonical_etype(etype)[0]).size(dim=0)* g.nodes(g.to_canonical_etype(etype)[2]).size(dim=0) - src, dst = g.edges(etype=etype) - #print(maximum) - #print("this is the point") - #print(self.src_weights) - #print(etype) - #print("this is post point") - n = len(src) - nximum=maximum - n - #print(nximum) - #print(n) - #print(maximum) - if(maximum - n < n): - n=(maximum - n) - if(n<=0): - n=len(src) - # Randomly sample src and dst - src, dst = src.numpy(), dst.numpy() - num_samples = n * self.k - #print(num_samples) - positive_edges = set(zip(src.tolist(), dst.tolist())) - #print(len(src) / maximum) - # Initialize containers for negative samples - neg_src = [] - neg_dst = [] - - if( len(src) / maximum > 0.5 or nximum < 3): - positive_edges = set(zip(src.tolist(), dst.tolist())) - if(len(src) > 1000 or nximum < 3): - #print("or enters this other") - user_nodes=range(g.num_nodes(g.to_canonical_etype(etype)[0])) - i=0 - if(nximum<=0): - - # Randomly select a user and an item - src_smp = random.sample(range(g.num_nodes(g.to_canonical_etype(etype)[0])),3) - dst_smp = random.sample(range(g.num_nodes(g.to_canonical_etype(etype)[2])),3) - - # Check if the selected user-item pair forms an existing edge of the specified type - - # Add the pair to negative samples - neg_src.append(src_smp) - neg_dst.append(dst_smp) - - else: - #print("enters this thing") - dst_num=set(range(g.num_nodes(g.to_canonical_etype(etype)[2]))) - indexed_tuples = defaultdict(list) - for idx, (key, value) in enumerate(positive_edges): - indexed_tuples[key].append(value) - left=set(range(g.num_nodes(g.to_canonical_etype(etype)[0])))-set(indexed_tuples[key].keys()) - for val in left: - indexed_tuples[val] - node_index=dict(sorted(indexed_tuples.items(), key=lambda item: len(item[1]))) - valid_combinations=list() - for (k,u) in node_index.items(): - # Randomly select a user and an item - node_index[k]=list(dst_num-set(u)) if len(dst_num) > len(u) else [] - #list_of_targets=random.sample(range(sizz),num_samples) - - node_index={k: u for (k,u) in node_index.items() if u != []} - for (k,u) in node_index.items(): - - valid_combinations+=list(zip([k]*len(u),u)) - valid_combinations=random.sample(valid_combinations,num_samples) if len(valid_combinations) > num_samples*2 else valid_combinations - - if len(valid_combinations) < 3: - valid_combinations+=list(random.sample(valid_combinations,3-len(valid_combinations))) - for index in valid_combinations: - neg_src.append(index[0]) - neg_dst.append(index[1]) - - del valid_combinations - del node_index - - - - - - else: - #print("alcanza esto") - valid_combinations=[] - - #print("llega aqui?") 
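- # Dense fallback: enumerate every absent (src, dst) pair for this edge type and
- # sample negatives from that complement, instead of rejection sampling.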
- filtered_sorted_tuples = sorted(positive_edges, key=lambda x: x[0]) - - # Index tuples with the same first element together - indexed_tuples = defaultdict(list) - for idx, (key, value) in enumerate(filtered_sorted_tuples): - indexed_tuples[key].append(value) - for i in range(g.num_nodes(g.to_canonical_etype(etype)[0])): - num_nodelist=range(g.num_nodes(g.to_canonical_etype(etype)[2])) - dst=list(set(num_nodelist)-set(indexed_tuples[i])) if len(indexed_tuples[i]) < len(num_nodelist) else [] - valid_combinations+=[(i,aa) for aa in dst] if len(indexed_tuples[i]) < len(num_nodelist) else [] - if len(valid_combinations) > num_samples: - sampled_combinations = np.random.choice(range(len(valid_combinations)), size=num_samples, replace=False) - valid_combinations=[valid_combinations[index] for index in sampled_combinations] - #print("alcanza esto") - - - # Randomly select from the valid combinations up to the required number of samples - if len(valid_combinations) >= num_samples: - sampled_combinations = np.random.choice(range(len(valid_combinations)), size=num_samples, replace=False) - for index in sampled_combinations: - neg_src.append(valid_combinations[index][0]) - neg_dst.append(valid_combinations[index][1]) - else: - #print("and this") - while len(neg_src) < num_samples: - if mode == "src_dst_corr": - # Sample with replacement - sampled_src = self.src_weights[etype].multinomial(num_samples, replacement=True).numpy() - sampled_dst = self.dst_weights[etype].multinomial(num_samples, replacement=True).numpy() - elif mode == "dst_corr": - sampled_src = np.random.choice(src, size=num_samples, replace=True) - sampled_dst = self.dst_weights[etype].multinomial(num_samples, replacement=True).numpy() - else: - sampled_src = self.src_weights[etype].multinomial(num_samples, replacement=True).numpy() - sampled_dst = np.random.choice(dst, size=num_samples, replace=True) - # Convert to set for faster lookup and filter out positive edges - #print(len(neg_src)) - #print(sampled_src) - #print(num_samples) - sampled_edges = set(zip(sampled_src, sampled_dst)) - positive_edges - - # Filter out self-loops and already chosen negative edges - sampled_edges = sampled_edges - set(zip(neg_src, neg_dst)) - # Add the newly sampled and filtered edges to the negative samples - for s, d in sampled_edges: - if len(neg_src) < num_samples: - neg_src.append(s) - neg_dst.append(d) - #print("reaches this") - while len(neg_src) < 3: - if mode == "src_dst_corr": - # Sample with replacement - - sampled_src = self.src_weights[etype].multinomial(num_samples, replacement=True).numpy() - sampled_dst = self.dst_weights[etype].multinomial(num_samples, replacement=True).numpy() - elif mode == "dst_corr": - sampled_src = np.random.choice(src, size=num_samples, replace=True) - sampled_dst = self.dst_weights[etype].multinomial(num_samples, replacement=True).numpy() - else: - sampled_src = self.src_weights[etype].multinomial(num_samples, replacement=True).numpy() - sampled_dst = np.random.choice(dst, size=num_samples, replace=True) - # Convert to set for faster lookup and filter out positive edges - sampled_edges = set(zip(sampled_src, sampled_dst)) - sampled_edges = {(s, d) for s, d in sampled_edges if (s, d) not in zip(neg_src, neg_dst)} - for s, d in sampled_edges: - - neg_src.append(s) - neg_dst.append(d) - - # Convert negative samples to tensors - neg_src = torch.tensor(neg_src, dtype=torch.long) - neg_dst = torch.tensor(neg_dst, dtype=torch.long) - - - # Randomly sample src and dst - - # Exclude values already present in the original 
edges - - - - # Ensure we have enough samples - - result_dict[g.to_canonical_etype(etype)] = (neg_src[:n * self.k], neg_dst[:n * self.k]) - #print(result_dict.keys()) - return result_dict -class NegativeSamplerSim(object): - """ - Negative sampler for homogeneous graphs. - - Args: - g (dgl.DGLGraph): Original graph. - k (int): Number of negative samples per positive sample. - neg_share (bool): Flag indicating whether to share negative samples. - - Returns: - tuple: Negative samples for source and destination nodes. - """ - def __init__(self, g, k, neg_share=False): - self.src_weights =torch.tensor(np.linspace(0.0,g.num_nodes(),num=g.num_nodes())) - self.dst_weights= torch.tensor(np.linspace(0.0,g.num_nodes(),num=g.num_nodes())) - self.k = k - self.neg_share = neg_share - - def __call__(self, g,mode="src_dst_corr"): - """ - Generate negative samples for given positive samples. - - Args: - g (dgl.DGLGraph): Original graph. - mode (str): Sampling mode, one of "src_dst_corr", "dst_corr", or "random". - - Returns: - tuple: Negative samples for source and destination nodes. - """ - result_dict=dict() - - src, dst = g.edges() - maximum= g.num_nodes()* g.num_nodes() - n = len(src) - nximum=maximum - n - #print(n) - #print(maximum) - if(maximum - n < n*self.k): - n=(maximum - n)/self.k - if(n<=0): - n=len(src) - # Randomly sample src and dst - if(mode=="src_dst_corr"): - srcR = self.src_weights.multinomial(n * self.k, replacement=True) - dstR = self.dst_weights.multinomial(n * self.k, replacement=True) - elif(mode=="dst_corr"): - if( len(src) < n*self.k): - srcR= src.repeat_interleave(k) - if(srcR.size(dim=0) > n*self.k): - srcR=src.multinomial(n*self.k) - dstR = self.dst_weights.multinomial(n * self.k, replacement=True) - else: - srcR = self.src_weights.multinomial(n * self.k, replacement=True) - if( len(src) < n*self.k): - dstR= dst.repeat_interleave(k) - if(dstR.size(dim=0) > n*self.k): - srcR=src.multinomial(n*self.k) - # Exclude values already present in the original edges - - srcR_dstR=torch.cat((srcR.unsqueeze(1),dstR.unsqueeze(1)),dim=1) - src_dst=torch.cat((src.unsqueeze(1),dst.unsqueeze(1)),dim=1) - #print(srcR_dstR) - srcR_dstR=torch.tensor(list(set(map(tuple,srcR_dstR.tolist())) - set(map(tuple,src_dst.tolist())))) - #print(srcR_dstR) - srcR,dstR = srcR_dstR[:, 0] if len(srcR_dstR) > 0 else torch.tensor([]), srcR_dstR[:, 1] if len(srcR_dstR) > 0 else torch.tensor([]) - # Ensure we have enough samples - count=0 - while len(srcR) < n * self.k and (count < 1000 or (len(srcR)<3 and nximum>3)): - if(mode=="src_dst_corr"): - additional_src = self.src_weights.multinomial(n * self.k - len(srcR), replacement=True) - additional_dst = self.dst_weights.multinomial(n * self.k - len(dstR), replacement=True) - elif(mode=="dst_corr"): - additional_dst = self.dst_weights.multinomial(n * self.k - len(dstR), replacement=True) - additional_src = src.multinomial(n*self.k-len(srcR)) - else: - additional_src = self.src_weights.multinomial(n * self.k - len(srcR), replacement=True) - additional_dst = dst.multinomial(n*self.k-len(dstR)) - add_src_add_dst=torch.cat((additional_src.unsqueeze(1),additional_dst.unsqueeze(1)),dim=1) - srcR_dstR=torch.cat((srcR.unsqueeze(1),dstR.unsqueeze(1)),dim=1) - src_dst=torch.cat((src.unsqueeze(1),dst.unsqueeze(1)),dim=1) - srcR_dstR=torch.tensor(list(set(map(tuple,srcR_dstR.tolist()+ add_src_add_dst.tolist())) - set(map(tuple,src_dst.tolist())))) - srcR,dstR = srcR_dstR[:, 0] if len(srcR_dstR) > 0 else torch.tensor([]), srcR_dstR[:, 1] if len(srcR_dstR) > 0 else 
torch.tensor([]) - count+=1 - while(len(srcR) < 3): - if(mode=="src_dst_corr"): - additional_src = self.src_weights.multinomial(3 - len(srcR), replacement=True) - additional_dst = self.dst_weights.multinomial(3 - len(dstR), replacement=True) - elif(mode=="dst_corr"): - additional_dst = self.dst_weights.multinomial(3 - len(dstR), replacement=True) - additional_src = src.multinomial(3-len(srcR)) - else: - additional_src = self.src_weights.multinomial(3 - len(srcR), replacement=True) - additional_dst = dst.multinomial(3-len(dstR)) - add_src_add_dst=torch.cat((additional_src.unsqueeze(1),additional_dst.unsqueeze(1)),dim=1) - srcR_dstR=torch.cat((srcR.unsqueeze(1),dstR.unsqueeze(1)),dim=1) - src_dst=torch.cat((src.unsqueeze(1),dst.unsqueeze(1)),dim=1) - srcR_dstR=torch.tensor(list(set(map(tuple,srcR_dstR.tolist()+ add_src_add_dst.tolist())))) - srcR,dstR = srcR_dstR[:, 0] if len(srcR_dstR) > 0 else torch.tensor([]), srcR_dstR[:, 1] if len(srcR_dstR) > 0 else torch.tensor([]) - result_dict = (srcR[:n * self.k], dstR[:n * self.k]) - neg_g=dgl.graph(result_dict,num_nodes=g.num_nodes()) - #print(neg_g.num_edges()) - #print(nximum) - neg_g.ndata['h']=g.ndata['h'] - neg_g.ndata['feat']=g.ndata['feat'] - return neg_g -class NegativeSampler: - """ - Generic negative sampler for homogeneous graphs. - - Args: - k (int): Number of negative samples per positive sample. - - Returns: - tuple: Negative samples and labels. - """ - def __init__(self, k=10): # negative sampling rate = 10 - self.k = k - - def sample(self, pos_samples, num_nodes): - """ - Generate negative samples for given positive samples. - - Args: - pos_samples (numpy.ndarray): Positive samples. - num_nodes (list): List containing the number of nodes for each node type. - - Returns: - tuple: Negative samples and labels. - """ - batch_size = len(pos_samples) - neg_batch_size = batch_size * self.k - neg_samples = np.tile(pos_samples, (self.k, 1)) - - values = np.random.randint(num_nodes[0], size=neg_batch_size) - values2 = np.random.randint(num_nodes[1], size=neg_batch_size) - choices = np.random.uniform(size=neg_batch_size) - subj = choices > 0.5 - obj = choices <= 0.5 - neg_samples[subj, 0] = values[subj] - neg_samples[obj, 1] = values2[obj] - samples = np.concatenate(pos_samples) - - # binary labels indicating positive and negative samples - labels = np.zeros(batch_size * (self.k + 1), dtype=np.float32) - labels[:batch_size] = 1 - - return torch.from_numpy(samples), torch.from_numpy(labels) - -class SubgraphIterator: - """ - Subgraph iterator for generating positive and negative samples. - - Args: - g (dgl.DGLHeteroGraph): Original heterogeneous graph. - num_rels (int): Number of unique relations in the graph. - sample_size (int): Number of positive samples to generate. - num_epochs (int): Number of iterations (epochs) for sampling. - - Returns: - tuple: Subgraph, unique node IDs, positive samples, and labels. 
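-
-    Example (illustrative sketch; `g` is the weighted training heterograph built above):
-        >>> it = SubgraphIterator(g, num_rels=len(g.etypes), sample_size=30000, num_epochs=10)
-        >>> sub_g, val_g, uniq_v, samples, labels = it[0]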
- """ - def __init__(self, g, num_rels, sample_size=30000, num_epochs=6000): - self.g = g - #print(g) - self.num_rels = num_rels - self.sample_size = sample_size - self.num_epochs = num_epochs - self.pos_sampler = GlobalUniform(g, sample_size) - self.neg_sampler = NegativeSampler() - - def __len__(self): - return self.num_epochs - - def __getitem__(self, i): - val=dict() - val_edges=dict() - rela=dict() - samples=dict() - labels=dict() - uniq_v=dict() - num_nodes=dict() - weights=dict() - val_weights=dict() - eidstot = self.pos_sampler.sample() - num_nodes={ntype: self.g.num_nodes(ntype) for ntype in self.g.ntypes} - for srcn,etype,dstn in self.g.canonical_etypes: - eids = eidstot[etype] - src, dst = self.g.edges(etype=etype) - src, dst = src.numpy(), dst.numpy() - #rel = self.g.edges[etype].data[eids].numpy() - # relabel nodes to have consecutive node IDs - uniq_val, edges = np.unique((src, dst), return_inverse=True) - uniq_v[etype]=uniq_val - #num_nodes = len(uniq_v) - # edges is the concatenation of src, dst with relabeled ID - #src, dst = np.reshape(edges, (2, -1)) - relabeled_data = np.stack((src, dst)).transpose() - - #sample, label = self.neg_sampler.sample(relabeled_data, (num_nodes[srcn],num_nodes[dstn] )) - - # use only half of the positive edges - chosen_ids = np.random.choice( - np.arange(len(src)), - size=int(len(src) * 0.8), - replace=False, - ) - val_chosen_ids=torch.tensor(list(set(np.arange(len(src)).tolist())-set(chosen_ids.tolist()))) - src_temp = src[chosen_ids] - dst_temp = dst[chosen_ids] - weights[etype]=self.g.edata['w'][(srcn,etype,dstn)].numpy()[chosen_ids] - val_weights[etype]=self.g.edata['w'][(srcn,etype,dstn)].numpy()[val_chosen_ids] - samples[etype]=relabeled_data - labels[etype]=np.zeros(len(src_temp), dtype=np.float32) - labels[etype][:]=1 - labels[etype]=torch.from_numpy(labels[etype]) - src_val = src[val_chosen_ids] - dst_val = dst[val_chosen_ids] - #rel = rel[chosen_ids] - #src, dst = np.concatenate((src, dst)), np.concatenate((dst, src)) - #rel = np.concatenate((rel, rel + self.num_rels)) - #rela[etype]=rel - val[(srcn,etype,dstn)]=(src_temp,dst_temp) - val_edges[(srcn,etype,dstn)]=(src_val,dst_val) - sub_g = dgl.heterograph(val, num_nodes_dict=num_nodes) - val_g = dgl.heterograph(val_edges,num_nodes_dict=num_nodes) - for src,ETYPE,dsttype in self.g.canonical_etypes: - #sub_g.edge[ETYPE].data = torch.from_numpy(rela[ETYPE]) - dst_degs = sub_g.in_degrees(self.g.nodes(dsttype), (src,ETYPE,dsttype)).clamp(min=1).float() - sub_g.nodes[dsttype].data['norm'] = 1. / dst_degs - sub_g.apply_edges(lambda edges: {'norm': edges.dst['norm']}, etype=(src,ETYPE,dsttype)) - val_dst_deg= val_g.in_degrees(self.g.nodes(dsttype), (src,ETYPE,dsttype)).clamp(min=1).float() - val_g.nodes[dsttype].data['norm'] = 1. / val_dst_deg - val_g.apply_edges(lambda edges: {'norm': edges.dst['norm']}, etype=(src,ETYPE,dsttype)) - weights[ETYPE]=torch.from_numpy(weights[ETYPE]) - val_weights[ETYPE]= torch.from_numpy(val_weights[ETYPE]) - #sub_g.edge[ETYPE].data["norm"] = dgl.norm_by_dst(sub_g).unsqueeze(-1) - sub_g.edata['w']= weights - uniq_v={ntype: torch.tensor([i for i in range(0, sub_g.num_nodes(ntype))]) for ntype in sub_g.ntypes} - val_g.edata['w'] = val_weights - return sub_g , val_g, uniq_v, samples, labels - -class HeteroGSLayer(nn.Module): - """ - Heterogeneous GraphSAGE layer. - - Args: - G (dgl.DGLHeteroGraph): Input heterogeneous graph. - hidden_channels (int): Number of input features. - out_channels (int): Number of output features. 
- nonLin_mode (str): Non-linearity mode, either "sigm" or "relu". - norm_mode (bool): Flag indicating whether to normalize the output. - aggr_mode (str): Aggregation mode, one of "mean", "sum", "rgcn", or "lstm". - mid_norm (bool): Flag indicating whether to normalize intermediate results. - mid_activation (bool): Flag indicating whether to apply non-linearity to intermediate results. - dropout (float): Dropout rate. - - Returns: - torch.Tensor: Output node features. - """ - def non_linearity(self, data, mode): - if mode == "sigm": - return torch.sigmoid(data) - elif mode == "relu": - return torch.relu(data) - - def message(self, edges): - #print(edges.src['h'].shape) - return {'m': edges.src['h'], 'e': torch.tensor([self.idxEdge[edges._etype[1]] for _ in edges.src['h']])} - - def reduce(self, nodes): - etyp = self.e2idx[nodes.mailbox['e'][0, 0].item()] - #print(nodes.mailbox['m'].shape) - agg = self.aggr(nodes.mailbox['m'], etyp, self.agg_mode) - if self.agg_mode == "rgcn": - #print(agg.shape) - #print(self.lin[etyp].shape) - agg = torch.matmul(agg, self.lin[etyp]) - - else: - agg = self.lin[etyp](torch.cat([agg, nodes.data['h']], dim=-1)) - - if self.mid_activation: - agg = self.non_linearity(agg, self.nonLin_mode) - if self.mid_norm: - agg = nn.norm(agg) - - return {'h': agg} - - def reset_parameters(self, etype): - gain = nn.init.calculate_gain('relu') - if self.agg_mode == 'pool': - nn.init.xavier_uniform_(self.pool[etype].weight, gain=gain) - if self.agg_mode == 'lstm': - self.lstm[etype].reset_parameters() - if self.agg_mode != 'rgcn': - nn.init.xavier_uniform_(self.lin[etype].weight, gain=gain) - else: - nn.init.xavier_uniform_(self.weight, gain=gain) - if self.num_bases < len(self.G.etypes): - nn.init.xavier_uniform_(self.w_comp,gain=gain) - - def aggr(self, T, etype, ag): - if ag == "mean": - return torch.mean(T, 1) - elif ag == "sum": - return torch.sum(T, 1) - elif ag == "rgcn": - return torch.sum(T, 1) - elif ag == "lstm": - batch_size = T.shape[0] - h = (T.new_zeros((1, batch_size, self.in_feat)), - T.new_zeros((1, batch_size, self.in_feat))) - _, (rst, _) = self.lstm[etype](torch.swapaxes(T,0,1), h) - return rst.squeeze(dim=0) - - def __init__(self, G, hidden_channels, out_channels, nonLin_mode="sigm", norm_mode=False, aggr_mode="mean", - mid_norm=False, mid_activation=False, dropout=0.0): - super(HeteroGSLayer, self).__init__() - self.agg_mode = aggr_mode - self.nonLin_mode = nonLin_mode - self.lin = nn.ModuleDict() - self.lstm = nn.ModuleDict() - self.drop = nn.Dropout(dropout) - self.mid_activation = mid_activation - self.norm_mode = norm_mode - self.mid_norm = mid_norm - self.in_feat = hidden_channels - self.out_feat = out_channels - self.G=G - self.num_bases = len(G.etypes) - self.idxEdge = dict(map(lambda i, j: (i, j), G.etypes, range(0, len(G.etypes)))) - self.e2idx = dict(map(lambda i, j: (i, j), range(0, len(G.etypes)), G.etypes)) - if aggr_mode == 'rgcn': - self.weight = nn.Parameter(torch.Tensor(self.num_bases, self.in_feat, self.out_feat)) - - if self.num_bases < len(G.etypes): - self.w_comp = nn.Parameter(torch.Tensor(self.num_rels, self.num_bases)) - weight = self.weight.view(hidden_channels, self.num_bases, out_channels) - self.weight = torch.matmul(self.w_comp, weight).view(len(G.etypes), self.in_feat, self.out_feat) - - self.lin = nn.ParameterDict() - self.root = nn.Linear(hidden_channels, out_channels) - self.mid_activation = False - self.mid_norm = False - - - for _, etyp, _ in G.canonical_etypes: - if aggr_mode == "lstm": - self.lstm[etyp] = 
nn.LSTM(hidden_channels, hidden_channels) - - if aggr_mode != "rgcn": - self.lin[etyp] = nn.Linear(hidden_channels * 2, out_channels) - self.reset_parameters(etyp) - - def forward(self, G): - """ - Forward pass of the Heterogeneous GraphSAGE layer. - - Args: - G (dgl.DGLHeteroGraph): Input heterogeneous graph. - - Returns: - torch.Tensor: Output node features. - """ - if(self.agg_mode == "rgcn"): - if self.num_bases < len(G.etypes): - weight = self.weight.view(hidden_channels, self.num_bases, out_channels) - self.weight = torch.matmul(self.w_comp, weight).view(len(G.etypes), self.in_feat, self.out_feat) - ww = torch.split(self.weight,1) - ww= list(map(torch.squeeze, ww)) - # Dropout on the input node features - feat_src = {ntype: self.drop(G.ndata['h'][ntype]) for ntype in G.ntypes} - G.ndata['h'] = feat_src - emb = {} - count=0 - for etyp in G.etypes: - # Handle RGCN mode - if self.agg_mode == 'rgcn': - self.lin[etyp] = ww[count] - count += 1 - # Define the message and reduce functions for message passing - emb[etyp] = (self.message, self.reduce) - # Multi-hop message passing with specified aggregation mode ("sum" in this case) - G.multi_update_all(emb, "sum") - # Handle RGCN mode: update destination node features using a linear layer - if(self.agg_mode == "rgcn"): - G.dstdata['h'] ={ntype: G.dstdata['h'][ntype]+self.root(feat_src[ntype]) for ntype in G.ntypes} - # Apply non-linearity to the updated destination node features - if self.nonLin_mode != "None": - G.dstdata['h'] = {ntype: self.non_linearity(G.dstdata['h'][ntype], self.nonLin_mode) for ntype in G.ntypes} - # Optionally normalize the output node features - if self.norm_mode: - G.dstdata['h'] = {ntype: nn.norm(G.dstdata['h'][ntype]) for ntype in G.ntypes} - return G.dstdata['h'] -class RGCN(nn.Module): - """ - Two-layer Relational Graph Convolutional Network (RGCN) module. - - Args: - G (dgl.DGLHeteroGraph): Input heterogeneous graph. - num_nodes (int): Number of nodes in the graph. - h_dim (int): Dimension of the node embeddings. - - Attributes: - emb (dgl.nn.HeteroEmbedding): Heterogeneous embedding layer. - conv1 (dglnn.HeteroGraphConv): First Heterogeneous Graph Convolution layer. - conv2 (dglnn.HeteroGraphConv): Second Heterogeneous Graph Convolution layer. - dropout (nn.Dropout): Dropout layer for regularization. - - Methods: - forward(g, nids): Forward pass of the RGCN module. - - """ - def __init__(self,G, num_nodes,in_dim, h_dim,dropout=0.5,device='cpu'): - super().__init__() - # two-layer RGCN - #embed = dgl.nn.HeteroEmbedding({ntype: G.num_nodes(ntype) for ntype in G.ntypes},h_dim) - #self.emb={ntype: nn.Embedding(G.num_nodes(ntype), len(G.ndata['feat'][ntype][0])).to(device) for ntype in G.ntypes} - #self.emb=embed - #self.emb = nn.Embedding(num_nodes, h_dim) - #[ G.edge[etype].data['w']=nn. 
for ] - print(max([len(G.ndata['feat'][src][0]) for src in G.ntypes])) - self.batch=[nn.ModuleDict({ntype: torch.nn.BatchNorm1d(h_dim).to(device) for ntype in G.ntypes}),nn.ModuleDict({ntype: torch.nn.BatchNorm1d(h_dim).to(device) for ntype in G.ntypes})] - self.conv1 = dglnn.HeteroGraphConv( - {etype: dglnn.SAGEConv((len(G.ndata['feat'][src][0]),len(G.ndata['feat'][dst][0])),h_dim,"mean") for (src,etype,dst) in G.canonical_etypes}, - aggregate='sum' - ) - self.conv2 = dglnn.HeteroGraphConv( - {etype: dglnn.SAGEConv(h_dim,h_dim,"mean") for etype in G.etypes}, - aggregate='sum' - ) - self.param=nn.ParameterDict({ntype: torch.nn.Parameter(torch.zeros(h_dim)).to(device) for ntype in G.ntypes}) - #gain = nn.init.calculate_gain('relu') - #for ntype in G.ntypes: - # nn.init.uniform_(self.param[ntype]) - self.dropout =nn.ModuleDict({ntype: nn.Dropout(dropout).to(device) for ntype in G.ntypes}) - #self.relu=nn.ModuleDict({ntype: torch.nn.LeakyReLU().to(device) for ntype in G.ntypes}) - - def forward(self, g, nids, w): - # Apply convolutional layers - #nids2={ntype: nids[ntype]+self.emb[ntype](torch.tensor(list(range(g.num_nodes(ntype)))).to(device)) for ntype in g.ntypes} - data2=dgl.heterograph({(src,etype,dst): g.edges(etype=etype) for (src,etype,dst) in g.canonical_etypes}, num_nodes_dict={ntype: g.num_nodes(ntype) for ntype in g.ntypes}) - data2.ndata['h']=nids - data2.edata['w']=w - kw={etype: {"edge_weight": w[(src,etype,dst)].to(torch.float)} for (src,etype,dst) in g.canonical_etypes} - h = self.conv1(data2, nids, mod_kwargs=kw) - h = forward_op(h, self.batch[0]) - #h = forward_op(h, self.relu) - h = {ntype: F.prelu(self.batch[0][ntype](h[ntype]), self.param[ntype]) for ntype in g.ntypes} - h = forward_op(h, self.dropout) - #h = {ntype: F.prelu(self.batch[0][ntype](h[ntype]), self.param[ntype]) for ntype in g.ntypes} - #h = {ntype: self.relu[ntype](self.batch[0][ntype](h[ntype])) for ntype in g.ntypes} - #h = {ntype: self.dropout[ntype](h[ntype]) for ntype in g.ntypes} - h = self.conv2(data2, h, mod_kwargs=kw) - #h = {ntype: self.dropout[ntype](h[ntype]) for ntype in g.ntypes} - #h = {ntype: self.batch[1][ntype](h[ntype]) for ntype in g.ntypes} - h = forward_op(h, self.batch[1]) - # Update node embeddings in the input graph - #for ntype in h: - # g.ndata['h'][ntype] = h[ntype] - - return h - - -class LinkPredict(nn.Module): - """ - Link Prediction module using RGCN and MLP. - - Args: - G (dgl.DGLHeteroGraph): Input heterogeneous graph. - num_nodes (int): Number of nodes in the graph. - h_dim (int): Dimension of the node embeddings. - reg_param (float): Regularization parameter. - - Attributes: - rgcn (RGCN): RGCN module for node embeddings. - reg_param (float): Regularization parameter. - W (nn.Linear): Linear layer for scoring. - w_relation (nn.Parameter): Learnable parameter for relation embeddings. - Pred (MLPPred): MLP prediction module. - - Methods: - calc_score(g, embedding, triplets, labels, etypeu=None): Calculate scores and labels. - forward(g, nids): Forward pass of the LinkPrediction module. - pred(g, emb, etype): Prediction based on embeddings. - regularization_loss(embedding, g): Compute regularization loss. - get_loss(g, neg_g, embed, triplets, neg_triplets, labels, neg_labels, etype=None): Calculate loss. 
- - """ - def __init__(self, G, num_nodes, in_dim=100, h_dim=500, reg_param=0.01,dropout=0.5, device='cpu'): - super().__init__() - self.rgcn = RGCN(G, num_nodes, in_dim, h_dim, dropout=dropout, device=device) - #self.reg_param = reg_param - #self.W = nn.Linear(h_dim*2, 1) - #num_rels = max([G.num_nodes(ntype) for ntype in G.ntypes]) - self.Pred = HeteroDotProductPredictor() # Initialize predictor here - #self.Pred=MLPPred(h_dim) - - def calc_score(self, g, embedding, triplets, labels, etypeu=None, device='cpu'): - """ - Calculate scores and labels. - """ - score = None - labelss = None - edgess = dict() - print(g.canonical_etypes) - for etype in g.etypes: - edgess[tuple(g.to_canonical_etype(etype))] = (triplets[etype][0], triplets[etype][1]) - score_graph = dgl.heterograph(edgess, num_nodes_dict={ntype: g.num_nodes(ntype) for ntype in g.ntypes}) - score_graph = score_graph.to(device) - score_graph.ndata['feat'] = embedding - self.Pred=self.Pred.to(device) - etypeu="dis_dru_the" - if etypeu is None: - for etype in g.etypes: - current_score = self.Pred(g, embedding, etype).squeeze() - if score is not None: - score = torch.cat((score, current_score), dim=-1) - labelss = torch.cat((labelss, labels[etype]), dim=-1) - else: - score = current_score - labelss = labels[etype] - else: - score = self.Pred(g, embedding, etypeu) - labelss = labels[etypeu] - return score, labelss - - def forward(self, g, nids, w): - """ - Forward pass of the LinkPrediction module. - """ - return self.rgcn(g, nids, w) - - def pred(self, g, emb, etype): - """ - Prediction based on embeddings. - """ - return self.Pred(g, emb, etype) - - def get_loss_contrast(self, emb,val_emb, etype=None, device='cpu'): - """ - Calculate loss. - """ - lf=torch.nn.CosineEmbeddingLoss - predict_loss = lf(emb.squeeze(), val_embed.to(device),torch.ones(emb.squeeze().shape[0])) - return predict_loss, (emb*val_emb).sum(-1) - def get_loss(self, emb, val_emb, ntype=None, device='cpu'): - """ - Calculate loss. - """ - score, label = self.calc_score(g, embed, triplets, labels, etypeu=etype, device=device) - neg_score, neg_label = self.calc_score(neg_g, embed, neg_triplets, neg_labels, etypeu=etype, device=device) - labelsss = torch.cat((torch.ones(len(score), dtype=torch.float32), torch.zeros(len(neg_score), dtype=torch.float32)), dim=-1) - predict_loss = F.binary_cross_entropy_with_logits(torch.cat((score, neg_score), dim=0).squeeze(), labelsss.to(device)) - return predict_loss, torch.cat((score, neg_score), dim=0), labelsss - -def filter( - triplets_to_filter, target_s, target_o, num_nodes, filter_o=True -): - """ - Get candidate heads or tails to score. - - Args: - triplets_to_filter (list): List of triplets to filter. - target_s (int): Target subject node. - target_o (int): Target object node. - num_nodes (list): List of number of nodes for each type. - filter_o (bool, optional): Filter object nodes. Defaults to True. - - Returns: - torch.LongTensor: List of candidate nodes. 
- - """ - target_s, target_o = int(target_s), int(target_o) - # Add the ground truth node first - if filter_o: - candidate_nodes = [target_o] - else: - candidate_nodes = [target_s] - if filter_o: - for e in range(num_nodes[1]): - triplet = ( - (target_s, e) - ) - # Do not consider a node if it leads to a real triplet - if triplet not in triplets_to_filter: - candidate_nodes.append(e) - else: - for e in range(num_nodes[0]): - triplet = ( - (e, target_o) - ) - # Do not consider a node if it leads to a real triplet - if triplet not in triplets_to_filter: - candidate_nodes.append(e) - return torch.LongTensor(candidate_nodes) - -def load_protein_emb(pro_dict): - """ - Load protein embeddings. - - Args: - pro_dict (dict): Dictionary mapping protein names to IDs. - - Returns: - torch.Tensor: Tensor of protein embeddings. - - """ - prot_emb_dict=export_protein_emb() - print(len(set(pro_dict.keys()))) - print(prot_emb_dict.columns) - for jj in (set(pro_dict.keys())-set(prot_emb_dict['name'])): - ll=pd.DataFrame(data=np.zeros([1,len(prot_emb_dict.columns)]),columns=prot_emb_dict.columns, index=range(1)) - print(ll) - ll['name']=jj - ll['embedding']=[[1]*(len(prot_emb_dict['embedding'].tolist()[0]))] - prot_emb_dict=pd.concat([prot_emb_dict, ll]) - # prot_emb_dict=pd.DataFrame([[1]*len(prot_emb_dict['embedding'].tolist()[0]), jj] ,columns=['embedding','name']) - prot_emb_dict['id']=prot_emb_dict.name.map(pro_dict) - prot_emb=prot_emb_dict.sort_values(by=['id']) - return torch.Tensor(prot_emb['embedding'].tolist()) -def load_drug_emb(device,mask,in_dim=100,h_dim=100,epochs=100): - """ - Load drug embeddings. - - Args: - device (torch.device): Torch device. - mask: Mask parameter (details are not provided in the code snippet). - in_dim (int, optional): Input dimension. Defaults to 100. - h_dim (int, optional): Hidden dimension. Defaults to 100. - epochs (int, optional): Number of training epochs. Defaults to 100. - - Returns: - torch.Tensor: Tensor of drug embeddings. - - """ - if(os.path.isfile("")): - drug_emb_dict=pd.from_csv("embeddings/drug_emb.tsv",sep='t') - else: - drug_emb_dict=generate_drug_embedding(device,in_dim,h_dim,mask,epochs=epochs) - drug_emb_dict['id']=drug_emb_dict.name.map(pro_dict) - drug_emb=prot_emb_dict.sort_values(by=['id']) - return torch.Tensor(drug_emb['embedding']) -def perturb_and_get_filtered_rank( - emb, w, s, o, test_size, triplets_to_filter, filter_o=True -): - - """ - Perturb subject or object in the triplets and get filtered rank. - - Args: - emb (list): List of node embeddings. - w: Relation embeddings. - s (list): List of subject nodes. - o (list): List of object nodes. - test_size (int): Size of the test set. - triplets_to_filter (list): List of triplets to filter. - filter_o (bool, optional): Filter object nodes. Defaults to True. - - Returns: - torch.LongTensor: List of ranks. 
- - """ - num_nodes = [emb[0].shape[0],emb[1].shape[0]] - ranks = [] - for idx in tqdm.tqdm(range(test_size), desc="Evaluate"): - target_s = s[idx] - #target_r = r[idx] - target_o = o[idx] - candidate_nodes = filter( - triplets_to_filter, - target_s, - target_o, - num_nodes, - filter_o=filter_o, - ) - if filter_o: - emb_s = emb[0][target_s] - emb_o = emb[1][candidate_nodes] - else: - emb_s = emb[0][candidate_nodes] - emb_o = emb[1][target_o] - target_idx = 0 - #emb_r = w[target_r] - emb_triplet = emb_s * emb_o - scores = torch.sigmoid(torch.sum(emb_triplet, dim=1)) - - _, indices = torch.sort(scores, descending=True) - rank = int((indices == target_idx).nonzero()) - ranks.append(rank) - return torch.LongTensor(ranks) -class HeteroDotProductPredictor(nn.Module): - def apply_edges(self,edges): - """ - Apply the MLP layers to the edges of a graph. - - Args: - edges: Edges of the graph. - - Returns: - dict: Dictionary with the 'score' key containing the MLP output. - - """ - return {'score': torch.sum((edges.src['h'] * edges.dst['h']),dim=-1).squeeze()} - def forward(self, graph, h, etype): - with graph.local_scope(): - graph.ndata['h'] = h - # Use dgl built-in u_dot_v function for computing dot product - #graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype) - graph.apply_edges(self.apply_edges,etype=etype) - # Apply sigmoid activation to the computed scores - return F.sigmoid(graph.edges[etype].data['score']) -class MLPPred(nn.Module): - """ - Multi-Layer Perceptron (MLP) for link prediction. - - Args: - n_val (int): Input dimension for the MLP. - - Attributes: - W1 (nn.Linear): First linear layer of the MLP. - W2 (nn.Linear): Second linear layer of the MLP. - - Methods: - non_lin(data, mode="sigm"): Apply non-linearity to the input data. - apply_edges(edges): Apply the MLP layers to the edges of a graph. - forward(g, h, etype): Forward pass of the MLP for link prediction. - - """ - def __init__(self,n_val): - super(MLPPred,self).__init__() - self.W1=nn.Linear(n_val*2,n_val) - self.W2=nn.Linear(n_val,1) - #self.W=w - def non_lin(self,data,mode="sigm"): - """ - Apply non-linearity to the input data. - - Args: - data (torch.Tensor): Input data. - mode (str, optional): Non-linearity mode ("sigm" or "relu"). Defaults to "sigm". - - Returns: - torch.Tensor: Transformed data. - - """ - if(mode=="sigm"): - return torch.sigmoid(data) - elif(mode=="relu"): - return torch.relu(data) - def apply_edges(self,edges): - """ - Apply the MLP layers to the edges of a graph. - - Args: - edges: Edges of the graph. - - Returns: - dict: Dictionary with the 'score' key containing the MLP output. - - """ - return {'score': self.W2(self.non_lin(self.W1(torch.cat((edges.src['h'],edges.dst['h']),dim=1)))).squeeze(1)} - def forward(self,g,h,etype): - """ - Forward pass of the MLP for link prediction. - - Args: - g (dgl.DGLHeteroGraph): Input heterogeneous graph. - h (torch.Tensor): Node embeddings. - etype (str): Edge type. - - Returns: - torch.Tensor: Predicted scores. - - """ - g.ndata['h']=h - g.apply_edges(self.apply_edges,etype=etype) - return g.edges[etype].data["score"] -def calc_auroc(pred,label, batch_size=100,etypeu=None, filter=True): - """ - Calculate Area Under the Receiver Operating Characteristic (AUROC). - - Args: - emb (torch.Tensor): Node embeddings. - w: Relation embeddings. - test_mask: Test mask parameter. - triplets_to_filter: List of triplets to filter. - g (dgl.DGLHeteroGraph): Input heterogeneous graph. - model (LinkPredict): Link prediction model. 
-        batch_size (int, optional): Batch size. Defaults to 100.
-        etypeu (str, optional): Edge type for calculation. Defaults to None.
-        filter (bool, optional): Filter parameter. Defaults to True.
-
-    Returns:
-        float: AUROC score.
-
-    """
-    auroc = roc_auc_score(label.cpu().detach(), pred.cpu().detach())
-    return auroc
-
-def calc_mrr(emb, w, test_mask, triplets_to_filter, g, batch_size=100, etypeu=None, filter=True):
-    """
-    Calculate Mean Reciprocal Rank (MRR).
-
-    Args:
-        emb (torch.Tensor): Node embeddings.
-        w: Relation embeddings.
-        test_mask: Test mask parameter.
-        triplets_to_filter: List of triplets to filter.
-        g (dgl.DGLHeteroGraph): Input heterogeneous graph.
-        batch_size (int, optional): Batch size. Defaults to 100.
-        etypeu (str, optional): Edge type for calculation. Defaults to None.
-        filter (bool, optional): Filter parameter. Defaults to True.
-
-    Returns:
-        torch.Tensor: Mean Reciprocal Rank.
-
-    """
-    mrr = None
-    with torch.no_grad():
-        for src, etype, dst in g.canonical_etypes:
-            test_triplets = triplets_to_filter[etype][test_mask[etype]]
-            s, o = test_triplets[:, 0], test_triplets[:, 1]
-            test_size = len(s)
-            triplets_to_filters = {
-                tuple(triplet) for triplet in triplets_to_filter[etype].tolist()
-            }
-            ranks_s = perturb_and_get_filtered_rank(
-                [emb[src], emb[dst]], w, s, o, test_size, triplets_to_filters, filter_o=False
-            )
-            ranks_o = perturb_and_get_filtered_rank(
-                [emb[src], emb[dst]], w, s, o, test_size, triplets_to_filters
-            )
-            ranks = torch.cat([ranks_s, ranks_o])
-            ranks += 1  # change to 1-indexed
-            if mrr is not None:
-                mrr = torch.cat(
-                    (mrr, torch.tensor([torch.mean(1.0 / ranks.float()).item()])), dim=0
-                )
-            else:
-                mrr = torch.tensor([torch.mean(1.0 / ranks.float()).item()])
-    return torch.mean(mrr, dim=0)
-
-def calc_prc(pred, label, batch_size=100, etypeu=None, filter=True):
-    """
-    Calculate the Precision-Recall Curve.
-
-    Args:
-        pred (torch.Tensor): Predicted scores.
-        label (torch.Tensor): Ground-truth labels.
-        batch_size (int, optional): Batch size. Defaults to 100.
-        etypeu (str, optional): Edge type for calculation. Defaults to None.
-        filter (bool, optional): Filter parameter. Defaults to True.
-
-    Returns:
-        Precision, recall, and average precision.
-
-    """
-    precision, recall, trh = metrics.precision_recall_curve(label.cpu().detach(), pred.cpu().detach())
-    # Area under the precision-recall curve
-    average_precision = metrics.average_precision_score(label.cpu().detach(), pred.cpu().detach())
-    return precision, recall, average_precision
-
-def calc_roc(pred, label, batch_size=100, etypeu=None, filter=True, plot=False, extension=''):
-    """
-    Calculate the Receiver Operating Characteristic (ROC) Curve.
-
-    Args:
-        pred (torch.Tensor): Predicted scores.
-        label (torch.Tensor): Ground-truth labels.
-        batch_size (int, optional): Batch size. Defaults to 100.
-        etypeu (str, optional): Edge type for calculation. Defaults to None.
-        filter (bool, optional): Filter parameter. Defaults to True.
-        plot (bool, optional): Whether to plot and save the curve. Defaults to False.
-        extension (str, optional): Suffix for the output file name. Defaults to ''.
-
-    Returns:
-        torch.Tensor: False Positive Rate, True Positive Rate, ROC AUC.
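Given the 1-indexed ranks collected above, the per-edge-type MRR is just the mean reciprocal rank:

    import torch

    ranks = torch.tensor([1, 3, 10, 2])        # 1-indexed ranks, both directions
    mrr = torch.mean(1.0 / ranks.float())      # (1 + 1/3 + 1/10 + 1/2) / 4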
- - """ - - - fpr, tpr, threshold = metrics.roc_curve(label.cpu().detach(),pred.cpu().detach()) - roc_auc = metrics.auc(fpr, tpr) - if plot: - label = " ".join(edge) + " " + 'AUC = %0.2f' % roc_auc - fileName = 'metrics/aucroc' + extension + '.svg' - random = [[0, 1], [0, 1], 'r--'] - - plotAndSaveFig(title='ROC Curve', x=fpr, y=tpr, label=label, path=fileName, - loc='lower right', xlim=[0, 1], ylim=[0, 1], xlabel='False Positive Rate', - ylabel='True Positive Rate', random=random) - - return fpr,tpr,roc_auc -def final_calc_prc(pred,label, batch_size=100,etypeu=None, filter=True,plot=False,extension=''): - """ - Calculate Precision-Recall Curve. - - Args: - emb (torch.Tensor): Node embeddings. - w: Relation embeddings. - test_mask: Test mask parameter. - triplets_to_filter: List of triplets to filter. - g (dgl.DGLHeteroGraph): Input heterogeneous graph. - model (LinkPredict): Link prediction model. - batch_size (int, optional): Batch size. Defaults to 100. - etypeu (str, optional): Edge type for calculation. Defaults to None. - filter (bool, optional): Filter parameter. Defaults to True. - - Returns: - torch.Tensor: Precision, Recall, Average Precision. - - """ - p=pred.squeeze().cpu().detach() - pred_label = np.zeros_like(p, dtype=np.int64) - pred_label[np.where(p > 0.5)[0]] = 1 - pred_label[np.where(p <= 0.5)[0]] = 0 - precision, recall, trh= metrics.precision_recall_curve(label.cpu().detach(), pred.cpu().detach()) - average_precision = metrics.precision_score(label.cpu().detach(), torch.tensor(pred_label)) - print(average_precision) - if plot: - label = " ".join(edge) + " " + 'RPC = %0.2f' % average_precision - fileName = 'metrics/prc' + extension + '.svg' - - plotAndSaveFig(title='Precision-Recall Curve', x=recall, y=precision, label=label, path=fileName, - loc='lower right', xlim=[0, 1], ylim=[0, 1], xlabel='Recall', ylabel='Precision') - - - return precision,recall,average_precision -def final_calc_roc(pred,label, batch_size=100,etypeu=None, filter=True): - """ - Calculate Receiver Operating Characteristic (ROC) Curve. - - Args: - emb (torch.Tensor): Node embeddings. - w: Relation embeddings. - test_mask: Test mask parameter. - g (dgl.DGLHeteroGraph): Input heterogeneous graph. - model (LinkPredict): Link prediction model. - batch_size (int, optional): Batch size. Defaults to 100. - etypeu (str, optional): Edge type for calculation. Defaults to None. - filter (bool, optional): Filter parameter. Defaults to True. - - Returns: - torch.Tensor: False Positive Rate, True Positive Rate, ROC AUC. - - """ - - - fpr, tpr, threshold = metrics.roc_curve(label.cpu().detach(),pred.cpu().detach()) - roc_auc = metrics.auc(fpr, tpr) - return fpr,tpr,roc_auc - -def train( - dataloader, - dataloader_val, - comp_g, - test_g, - neg_test_g, - test_nids, - test_mask, - triplets, - device, - model_state_file, - model, - h_dim=100, - etype=None, - sweep=None, - w=None, - epochs=100, - learning_rate=0.01, - weight_decay=0.005144745056173074, - eps=1e-2, - betas=(0.9,0.99) -): - """ - Training function for the link prediction model. - - Args: - dataloader: DataLoader for positive training examples. - dataloader_neg: DataLoader for negative training examples. - test_g (dgl.DGLHeteroGraph): Test graph. - neg_test_g (dgl.DGLHeteroGraph): Negative test graph. - test_nids: Test node IDs. - test_mask: Test mask parameter. - triplets: List of triplets. - device: Device for training. - model_state_file: File to save the model state. - model (LinkPredict): Link prediction model. 
- h_dim (int, optional): Hidden dimension. Defaults to 100. - etype (str, optional): Edge type. Defaults to None. - - """ - optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate,weight_decay=weight_decay,eps=eps,betas=betas) - best_mrr = 0 - best_auroc=0 - best_roc=0 - best_avp=0 - best_model=None - avp_list=[] - roc_list=[] - model = model.to(device) - sampler=NewNegativeSampler() - comp_g.ndata['h']=test_nids - nu_nodes={ntype: comp_g.num_nodes(ntype) for ntype in comp_g.ntypes} - #gg=dgl.merge([dataloader,dataloader_val]) - #print(gg.edata) - #gg=gg.to(device) - for jj in tqdm.tqdm(range(epochs),total=epochs): - for epoch, val_batch_data in enumerate(dataloader): - model.train() - optimizer.zero_grad() - epoch=jj - #g, train_nids, edges, labels = batch_data - g,val_g,val_nids,val_edges,val_labels= val_batch_data - #g=dataloader - #val_g=dataloader_val - labels={etype: torch.ones(g.num_edges(etype)) for etype in g.etypes} - val_labels={etype: torch.ones(val_g.num_edges(etype)) for etype in val_g.etypes} - neg_eids=sampler(comp_g,{}) - neg_eids={(src,etype,dst): (neg_eids[(src,etype,dst)][0][:val_g.num_edges(etype)],neg_eids[(src,etype,dst)][1][:val_g.num_edges(etype)]) for (src,etype,dst) in comp_g.canonical_etypes} - neg_gg=dgl.heterograph(neg_eids,num_nodes_dict=nu_nodes) - neg_gg.ndata['h']=comp_g.ndata['h'] - g = g.to(device) - val_g=val_g.to(device) - #neg_val_g=neg_val_g.to(device) - #neg_g=neg_g.to(device) - #train_nids = train_nids.to(device) - #edges = edges.to(device) - #labels = labels.to(device) - test_nids={k: v.to(torch.float).to(device) for (k,v) in test_nids.items()} - print(g.edata['w']) - embed = model(g.to(device), test_nids,g.edata['w']) - - neg_gg=neg_gg.to(device) - edges={etype: val_g.edges(etype=etype) for etype in val_g.etypes} - neg_edges={etype: neg_gg.edges(etype=etype) for etype in neg_gg.etypes} - #print(device) - loss,pred,label = model.get_loss(val_g.to(device),neg_gg.to(device),embed, edges,neg_edges, labels,val_labels,etype=etype,device=device) - - loss.backward() - #nn.utils.clip_grad_norm_( - # model.parameters(), max_norm=1.0 - #) # clip gradients - optimizer.step() - p=pred.cpu().detach() - pred_label = np.zeros_like(p, dtype=np.int64) - pred_label[np.where(p > 0.5)[0]] = 1 - pred_label[np.where(p <= 0.5)[0]] = 0 - #acc = np.sum(pred_label == label.cpu().numpy()) - acc=1 - - num = len(pred_label) - - print( - "Epoch {:04d} | Loss {:.4f} | Best MRR {:.4f} | Best Auroc {:.4f} | Best PRC {:.4f} | Best ROC {:.4f} | ACC {:.4f}".format( - epoch, loss.item(), best_mrr,best_auroc,best_avp,best_roc,acc/num - ) - ) - if (epoch + 1) % 10 == 0 or (jj + 1) % 10 == 0: - # perform validation on CPU because full graph is too large - - model.eval() - gg=dgl.merge([g,val_g]) - embed = model(gg, test_nids,gg.edata['w']) - loss,pred,labels = model.get_loss(test_g.to(device),neg_test_g.to(device),embed, edges,neg_edges, labels,val_labels,etype=etype,device=device) - auroc= calc_auroc( - pred,labels, batch_size=500, etypeu=etype - ) - - if best_auroc < auroc: - best_auroc = auroc - best_model=model - #mrr = calc_mrr( - # embed, model.w_relation, test_mask, triplets,test_g,model, batch_size=500 - #) - # save best model - #if best_mrr < mrr: - # best_mrr = mrr - # torch.save( - # {"state_dict": model.state_dict(), "epoch": epoch}, - # model_state_file, - # ) - prr,recall,avp= calc_prc( - pred,labels, batch_size=500, etypeu=etype - ) - avp_list.append(avp) - if best_avp < avp: - best_avp = avp - best_model=model - fpr,tpr,auc= calc_roc( - pred,labels, 
batch_size=500, etypeu=etype - ) - wandb.log({"loss": loss,"precision": avp}) - roc_list.append(auc) - if best_roc < auc: - best_roc = auc - best_model=model.cpu() - model = model.to(device) - if(device=="cuda"): - torch.cuda.empty_cache() - #best_model=model - return best_model,roc_list,avp_list -def plotAndSaveFig(title, x, y, label, path, loc=None, xlim=None, ylim=None, xlabel=None, ylabel=None, random=None): - """ - Plot and save a figure. - - Args: - title (str): Title of the plot. - x (list): x-axis data. - y (list): y-axis data. - label (str): Label for the plot. - path (str): Path to save the plot. - loc (str, optional): Location of the legend. Defaults to None. - xlim (tuple, optional): Limits for the x-axis. Defaults to None. - ylim (tuple, optional): Limits for the y-axis. Defaults to None. - xlabel (str, optional): Label for the x-axis. Defaults to None. - ylabel (str, optional): Label for the y-axis. Defaults to None. - random (list, optional): Random data to plot. Defaults to None. - - """ - plt.title(title) - plt.plot(x, y, label=label) - - if loc is not None: - plt.legend(loc=loc) - - if xlim is not None: - plt.xlim(xlim) - - if ylim is not None: - plt.ylim(ylim) - - if xlabel is not None: - plt.xlabel(xlabel) - - if ylabel is not None: - plt.ylabel(ylabel) - - if random is not None: - plt.plot(random[0], random[1], random[2]) - - plt.show() - plt.savefig(path, format='svg', dpi=1200) - plt.clf() -def train_fold_with_samples(dataloader,val_dataloader,comp_g,model,pos_test_g,neg_test_g,test_nids,test_mask,triplets,device,h_dim=100,etype=None,w=None,learning_rate=0.01,weight_decay=0.005, epochs=100, eps=0.001, betas=(0.9,0.99)): - """ - Training function for a fold of cross-validation. - - Args: - dataloader: DataLoader for positive training examples. - neg_dataloader: DataLoader for negative training examples. - model (LinkPredict): Link prediction model. - pos_val_g (dgl.DGLHeteroGraph): Positive validation graph. - neg_val_g (dgl.DGLHeteroGraph): Negative validation graph. - pos_test_g (dgl.DGLHeteroGraph): Positive test graph. - neg_test_g (dgl.DGLHeteroGraph): Negative test graph. - test_nids: Test node IDs. - test_mask: Test mask parameter. - triplets: List of triplets. - device: Device for training. - h_dim (int, optional): Hidden dimension. Defaults to 100. - etype (str, optional): Edge type. Defaults to None. - - Returns: - torch.Tensor: Embeddings. - float: ROC AUC. - float: Average Precision. 
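For reference, plotAndSaveFig as defined above renders a single ROC curve like this (the fpr/tpr values are made up):

    fpr, tpr = [0.0, 0.2, 1.0], [0.0, 0.7, 1.0]
    plotAndSaveFig(title='ROC Curve', x=fpr, y=tpr, label='AUC = 0.75',
                   path='metrics/aucroc_example.svg', loc='lower right',
                   xlim=[0, 1], ylim=[0, 1], xlabel='False Positive Rate',
                   ylabel='True Positive Rate', random=[[0, 1], [0, 1], 'r--'])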
- - """ - model_state_file = "model_state.pth" - if train: - model.train() - optimizer.zero_grad() - # g, train_nids, edges, labels = batch_data - # g=dataloader - # val_g=dataloader_val - g = g.to(device) # neg_val_g=neg_val_g.to(device) - # neg_g=neg_g.to(device) - # train_nids = train_nids.to(device) - # edges = edges.to(device) - # labels = labels.to(device) - test_nids = {k: v.to(torch.float).to(device) for (k, v) in test_nids.items()} - print(g.edata['w']) - embed = model(g.to(device), test_nids, g.edata['w']) - edges = {etype: val_g.edges(etype=etype) for etype in val_g.etypes} - # print(device) - loss, pred = model.get_loss_contrast(embed, val_embed,ntype=ntype, device=device) - - loss.backward() - # nn.utils.clip_grad_norm_( - # model.parameters(), max_norm=1.0 - # ) # clip gradients - optimizer.step() - best_model=embed - # acc = np.sum(pred_label == label.cpu().numpy()) - acc = 1 - prr, recall, avrp = calc_prc( - pred, torch.ones(pred.shape), batch_size=500, etypeu=etype - ) - fpr, tpr, auc = calc_roc( - pred, torch.ones(pred.shape), batch_size=500, etypeu=etype - ) - else: - model.eval() - g, val_g, val_nids, val_edges, val_labels = val_batch_data - # g=dataloader - # val_g=dataloader_val - labels = {etype: torch.ones(g.num_edges(etype)) for etype in g.etypes} - val_labels = {etype: torch.ones(val_g.num_edges(etype)) for etype in val_g.etypes} - g = g.to(device) - val_g = val_g.to(device) - test_nids = {k: v.to(torch.float).to(device) for (k, v) in test_nids.items()} - print(g.edata['w']) - embed = model(g.to(device), test_nids, g.edata['w']) - loss, pred, labels = model.get_loss_contrast(test_g.to(device), neg_test_g.to(device), embed, edges, neg_edges, labels, - val_labels, etype=etype, device=device) - prr, recall, avrp = calc_prc( - pred, torch.ones(pred.shape), batch_size=500, etypeu=etype - ) - fpr, tpr, auc = calc_roc( - pred, torch.ones(pred.shape), batch_size=500, etypeu=etype - ) - wandb.log({"loss": loss, "precision": avp}) - best_model = embed - if(device=="cuda"): - torch.cuda.empty_cache() - # testing - print("Testing...") - - return best_model,auc,avrp -def find(subp_g, val_list, pos_list): - """ - Find positive and validation subgraphs. - - Args: - subp_g (dgl.DGLHeteroGraph): Subgraph. - val_list: List of validation graphs. - pos_list: List of positive graphs. - - Returns: - Tuple: Validated positive graphs, validation graphs, filtered positive indices. 
- - """ - pos_ffound = set(range(len(pos_list))) - pos_already_removed = 0 - - for etype in pos_list[0].etypes: - pos_match = torch.zeros(subp_g.num_edges(etype=etype)) - val_match = torch.zeros(subp_g.num_edges(etype=etype)) - src, dst = subp_g.edges(etype=etype) - - for u in range(len(subp_g.edges(etype=etype))): - # Positive subgraph matching - for g in pos_list: - ssrc, ddst = g.edges(etype=etype) - index = (src.unsqueeze(1) == ssrc).nonzero(as_tuple=False).squeeze() - indexT=torch.transpose(index,0,1) - pos_match[index[torch.eq(dst[indexT[0]],ddst[indexT[1]]).nonzero()]] = 1 - - # Validation subgraph matching - for g in val_list: - ssrc, ddst = g.edges(etype=etype) - index = (src.unsqueeze(1) == ssrc).nonzero(as_tuple=False).squeeze() - indexT=torch.transpose(index,0,1) - val_match[index[torch.eq(dst[indexT[0]], ddst[indexT[1]]).nonzero()]] = 1 - - # Positive graphs removal - if pos_match.sum() < len(src) or val_match.sum() < len(src): - ext=min(pos_match.sum(),val_match.sum()) - #print(int(len(pos_ffound) * ((len(src) - ext) / (len(src)*len(pos_list))))>0) - #print(len(pos_ffound)) - if(int(len(pos_ffound) * ((len(src) - ext) / (len(src)*len(pos_list))))>0 and len(pos_ffound)>0): - pos_ffound.difference_update(random.sample(pos_ffound,(len(pos_ffound)*((len(src) - ext) / (len(src)*len(pos_list)))))) - - return val_list, pos_list, pos_ffound - -def plotMetrics(fpr, tpr, label1, recall, precision, label2): - # Vertical plotting. - fig, axs = plt.subplots(2, figsize=(6, 10)) - - axs[0].plot(fpr, tpr, label="AUC ROC = " + np.array2string(label1, formatter={'float_kind': lambda x: "%.2f" % x})) - axs[0].set_title('ROC Curve') - axs[0].legend(loc='lower right') - axs[0].plot([0, 1], [0, 1], 'r--') - axs[0].set_xlim([0, 1]) - axs[0].set_ylim([0, 1]) - axs[0].set_ylabel('True Positive Rate') - axs[0].set_xlabel('False Positive Rate') - - axs[1].set_title('Precision-Recall Curve') - axs[1].plot(recall, precision, - label="PRC = " + np.array2string(label2, formatter={'float_kind': lambda x: "%.2f" % x})) - axs[1].legend(loc='lower right') - axs[1].set_xlim([0, 1]) - axs[1].set_ylim([0, 1]) - axs[1].set_ylabel('Precision') - axs[1].set_xlabel('Recall') - - # Horizontal plotting. - fig2, axs2 = plt.subplots(1, 2, figsize=(12, 4)) - - axs2[0].plot(fpr, tpr, label="AUC ROC = " + np.array2string(label1, formatter={'float_kind': lambda x: "%.2f" % x})) - axs2[0].set_title('ROC Curve') - axs2[0].legend(loc='lower right') - axs2[0].plot([0, 1], [0, 1], 'r--') - axs2[0].set_xlim([0, 1]) - axs2[0].set_ylim([0, 1]) - axs2[0].set_ylabel('True Positive Rate') - axs2[0].set_xlabel('False Positive Rate') - - axs2[1].set_title('Precision-Recall Curve') - axs2[1].plot(recall, precision, - label="PRC = " + np.array2string(label2, formatter={'float_kind': lambda x: "%.2f" % x})) - axs2[1].legend(loc='lower right') - axs2[1].set_xlim([0, 1]) - axs2[1].set_ylim([0, 1]) - axs2[1].set_ylabel('Precision') - axs2[1].set_xlabel('Recall') - - fig.savefig('metrics/aucroc&prcRepoDBVertical.svg', format='svg', dpi=1200) - fig2.savefig('metrics/aucroc&prcRepoDBHorizontal.svg', format='svg', dpi=1200) - plt.close(fig) - plt.close(fig2) - -def random_fold_loop(g, mask, k,etype,ss,dd, max_iterations=100): - """ - Randomly generate validation and test sets for k-fold cross-validation ensuring all edges are in at least one test and validation set of a fold. - - Args: - g (dgl.DGLHeteroGraph): Original graph. - mask: Mask parameter. - k (int): Number of folds. - etype: Edge type. - ss: Source node type. 
- dd: Destination node type. - max_iterations (int, optional): Maximum iterations for random sampling. Defaults to 100. - - Returns: - Tuple: Validated positive graphs, validation graphs, training graphs, iteration. - - """ - train_list, val_list, test_list = [dict() for _ in range(k)], [dict() for _ in range(k)], [dict() for _ in range(k)] - test_edg, val_edg = [None for _ in range(k)], [None for _ in range(k)] - test_left_mask, val_left_mask = [None for _ in range(k)], [0.0 for _ in range(k)] - num_edges = g.num_edges(etype) - src, dst = g.edges(etype=etype) - edges = pd.DataFrame({'src': src.squeeze(), 'dst': dst.squeeze()}).drop_duplicates() - num_edges=edges.shape[0] - val_edges, test_edges = edges.copy(), edges.copy() - test_rest, val_rest=None, None - continue_loop= False - i=0 - - prob, iteration = 1 / k, 0 - continue_loop=True - test_presamples=test_edges.sample(frac=1).reset_index(drop=True) - val_presamples=val_edges.sample(frac=1).reset_index(drop=True) - if(test_presamples.iloc[0:1,:].equals(val_presamples.iloc[0:1,:])): - qq=random.randint(1,len(val_presamples)) - new_u=val_presamples.iloc[qq:qq+1,:] - val_presamples.iloc[qq:qq+1,:]=val_presamples.iloc[0:1,:] - val_presamples.iloc[0:1,:]=new_u - already_sampled=None - for i in range(num_edges): - if(test_presamples.iloc[i:i+1,:].equals(val_presamples.iloc[i:i+1,:])): - if(i==0): - break - liss=[uu for uu in range(len(val_presamples)) if uu != i] - qq=random.choice(liss) - new_u=val_presamples.iloc[qq:qq+1,:] - val_presamples.iloc[qq:qq+1,:]=val_presamples.iloc[i:i+1,:] - #print(qq) - #print(new_u) - val_presamples.iloc[i:i+1,:]=new_u - #if already_sampled is not None: - """ - sampler=pd.concat([test_presamples.iloc[i:i+1,:], already_sampled], ignore_index=True).drop_duplicates() - merged = pd.merge(val_edges,sampler, how='outer', indicator=True) - result = merged[merged['_merge'] == 'left_only'] - result = result.drop('_merge', axis=1) - if(result.empty): - qq=random.randint(list(range(len(already_sampled)))) - new_u=already_sampled.loc[qq:qq+1,:] - already_sampled.loc[qq:qq+1,:]=test_presamples.iloc[i:i+1,:] - result=new_u - - - already_sampled=pd.concat([already_sampled,result.sample(1)],ignore_index=True) - else: - merged = pd.merge(val_edges,test_presamples.iloc[i:i+1,:], how='outer', indicator=True) - result = merged[merged['_merge'] == 'left_only'] - result = result.drop('_merge', axis=1) - if(result.empty): - continue_loop=False - break - already_sampled=result.sample(1) - - """ - - - #val_presamples=already_sampled - prev_num_edg=num_edges - #print(val_presamples) - if num_edges%k != 0: - - num_edges= num_edges-num_edges%k if num_edges%k > 0 else num_edges - - test_precalc=math.floor( - prev_num_edg * mask['test']) if math.floor( - prev_num_edg * mask['test']) >= 1 else 1 - val_precalc= math.floor( - prev_num_edg * mask['val']) if math.floor( - prev_num_edg * mask['val']) >= 1 else 1 - for i in range(k): - test_left_mask[i], val_left_mask[i] = 0.0, 0.0 - #print(math.ceil(i*(prob)*num_edges)) - #print(num_edges) - #print(math.ceil((i+1)*(prob)*num_edges)) - test_presample, val_presample = test_presamples.iloc[math.floor(i*(prob)*num_edges):math.floor((i+1)*(prob)*num_edges),:],val_presamples.iloc[math.floor(i*(prob)*num_edges):math.floor((i+1)*(prob)*num_edges),:] - print(val_presample) - if(test_presample.shape[0] == 0): - test_presample=test_edges.sample(1) - if(val_presample.shape[0] == 0): - val_presamples = val_edges.sample(1) - - #print([len(val_presample),val_precalc]) - test_edg[i] = 
test_presample.iloc[:test_precalc, :] if len(test_presample) >= test_precalc else test_presample - test_left_mask[i] = (test_precalc - len(test_presample)) / prev_num_edg if len(test_presample) < test_precalc else 0 - val_edg[i] = val_presample.iloc[:val_precalc, :] if len(val_presample) >= val_precalc else val_presample - val_left_mask[i]= (val_precalc - len(val_presample)) / prev_num_edg if len(val_presample) < val_precalc else 0 - merged = pd.merge(test_edges,test_edg[i], how='outer', indicator=True) - test_edges = merged[merged['_merge'] == 'left_only'] - test_edges = test_edges.drop('_merge', axis=1) - merged = pd.merge(val_edges,val_edg[i], how='outer', indicator=True) - val_edges = merged[merged['_merge'] == 'left_only'] - val_edges = val_edges.drop('_merge', axis=1) - #print(val_presample.loc[:val_precalc, :]) - #print(val_edg) - if(test_left_mask[0]>=1 or val_left_mask[0]>=1 and prev_num_edg-num_edges > 0): - indexes=random.sample(range(k),prev_num_edg-num_edges) - print(test_edges.shape) - test_edges=test_edges.sample(1) - val_edges=val_edges.sample(1) - num=0 - for i in range(k): - if(i not in indexes): - merged = pd.merge(edges,val_edg[i], how='outer', indicator=True) - edges_filtered = merged[merged['_merge'] == 'left_only'] - edges_filtered = edges_filtered.drop('_merge', axis=1) - merged=pd.merge(edges_filtered,test_edg[i], how='outer', indicator=True) - edges_filtered = merged[merged['_merge'] == 'left_only'] - edges_filtered = edges_filtered.drop('_merge', axis=1) - test_edg[i]=pd.concat([test_edg[i],edges_filtered],ignore_index=True) if test_left_mask[i]>=1 else test_edg[i] - - val_edg[i]=pd.concat([val_edg[i],edges_filtered]) if val_left_mask[i]>=1 else val_edg[i] - else: - test_edg[i]=pd.concat([test_edg[i],test_edges.iloc[num:num+1]],ignore_index=True) if test_left_mask[i]>=1 else test_edg[i] - val_edg[i]=pd.concat([val_edg[i],val_edges.iloc[num:num+1]],ignore_index=True) if val_left_mask[i]>=1 else val_edg[i] - test_left_mask[i]-=1 / prev_num_edg - val_left_mask[i]-=1 / prev_num_edg - if continue_loop: - previous_loops_test=None - previous_loops_val=None - for i in range(k): - if val_edg[i] is not None: - merged = pd.merge(edges,val_edg[i], how='outer', indicator=True) - edges_filtered = merged[merged['_merge'] == 'left_only'] - edges_filtered = edges_filtered.drop('_merge', axis=1) - new_edges = edges_filtered if val_edg[i] is not None else edges - if test_edg[i] is not None: - merged = pd.merge(new_edges,test_edg[i], how='outer', indicator=True) - edges_filtered = merged[merged['_merge'] == 'left_only'] - edges_filtered = edges_filtered.drop('_merge', axis=1) - new_edges = edges_filtered if test_edg[i] is not None else new_edges - if i>0: - previous_loops_val=None - previous_loops_test=None - val_test=pd.concat([test_edg[i],val_edg[i]], ignore_index=True).drop_duplicates() - for y in range(i): - merged = pd.merge(test_edg[y],val_test, how='outer', indicator=True) - edges_filtered = merged[merged['_merge'] == 'left_only'] - edges_filtered = edges_filtered.drop('_merge', axis=1) - if not edges_filtered.empty: - previous_loops_val=pd.concat([previous_loops_val,edges_filtered.sample(1)], ignore_index=True).drop_duplicates() if previous_loops_val is not None else edges_filtered.sample(1) - merged = pd.merge(val_edg[y],val_test, how='outer', indicator=True) - edges_filtered = merged[merged['_merge'] == 'left_only'] - edges_filtered = edges_filtered.drop('_merge', axis=1) - if not edges_filtered.empty: - 
previous_loops_test=pd.concat([previous_loops_test,edges_filtered.sample(1)], ignore_index=True).drop_duplicates() if previous_loops_test is not None else edges_filtered.sample(1) - #predges = new_edges.copy() - #print(test_left_mask[i]) - if test_left_mask[i] > 0 and math.floor(prev_num_edg* test_left_mask[i]) > 0 and new_edges.shape[0] != 0: - if previous_loops_test is not None: - merged = pd.merge(new_edges,previous_loops_test, how='outer', indicator=True) - edges_loop_filter = merged[merged['_merge'] == 'left_only'] - edges_loop_filter = edges_loop_filter.drop('_merge', axis=1) - new_edges_test = edges_loop_filter if previous_loops_test is not None else new_edges - new_edges_test = predges if new_edges_test.empty else new_edges - last_test_sample = new_edges_test.sample(math.floor(prev_num_edg* test_left_mask[i])) - test_edg[i] = pd.concat([test_edg[i], last_test_sample], ignore_index=True) - merged = pd.merge(new_edges,last_test_sample, how='outer', indicator=True) - new_edges = merged[merged['_merge'] == 'left_only'] - new_edges = new_edges.drop('_merge', axis=1) - - if val_left_mask[i] > 0 and math.floor(prev_num_edg* val_left_mask[i]) >0 and new_edges.shape[0] != 0: - if previous_loops_val is not None: - merged = pd.merge(new_edges,previous_loops_val, how='outer', indicator=True) - edges_loop_filter = merged[merged['_merge'] == 'left_only'] - edges_loop_filter = edges_loop_filter.drop('_merge', axis=1) - new_edges_val = edges_loop_filter if previous_loops_val is not None else new_edges - new_edges_val = predges if new_edges_val.empty else new_edges - last_val_sample = new_edges_val.sample(math.floor(prev_num_edg* val_left_mask[i])) - val_edg[i] = pd.concat([val_edg[i], last_val_sample], ignore_index=True) - merged = pd.merge(new_edges,last_val_sample, how='outer', indicator=True) - new_edges = merged[merged['_merge'] == 'left_only'] - new_edges = new_edges.drop('_merge', axis=1) - #print([test_edg[i].shape[0],val_edg[i].shape[0],new_edges.shape[0]]) - if(new_edges.shape[0] == 0): - if(test_edg[i].shape[0] >1): - new_edges_test = test_edg[i].sample(1) - new_edges=pd.concat([new_edges, new_edges_test], ignore_index=True) - merged = pd.merge(test_edg[i],new_edges_test, how='outer', indicator=True) - test_edg[i] = merged[merged['_merge'] == 'left_only'] - test_edg[i] = test_edg[i].drop('_merge', axis=1) - if(val_edg[i].shape[0] >1): - new_edges_val = val_edg[i].sample(1) - new_edges=pd.concat([new_edges, new_edges_val], ignore_index=True) - #print("-------pre--------------") - #print(val_edg[i]) - merged = pd.merge(val_edg[i],new_edges_val, how='outer', indicator=True) - val_edg[i] = merged[merged['_merge'] == 'left_only'] - val_edg[i] = val_edg[i].drop('_merge', axis=1) - #print("-------post--------------") - #print(val_edg[i]) - if(len(val_edg[i]['src'].tolist())<=0 or len(test_edg[i]['src'].tolist()) <=0 or len(new_edges['src'].tolist())<=0): - iteration=max_iterations+1 - #print([test_edg[i], val_edg[i]]) - #print(new_edges) - merged = pd.merge(val_edg[i],test_edg[i], how='outer', indicator=True) - val_edg[i] = merged[merged['_merge'] == 'left_only'] - val_edg[i] = val_edg[i].drop('_merge', axis=1) - merged = pd.merge(new_edges,test_edg[i], how='outer', indicator=True) - new_edges = merged[merged['_merge'] == 'left_only'] - new_edges = new_edges.drop('_merge', axis=1) - merged = pd.merge(new_edges,val_edg[i], how='outer', indicator=True) - new_edges = merged[merged['_merge'] == 'left_only'] - new_edges = new_edges.drop('_merge', axis=1) - new_edges=new_edges.drop_duplicates() - 
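random_fold_loop repeats one pandas idiom throughout: remove the rows of one edge table from another via an outer merge with an indicator column. As a reusable helper (a refactoring suggestion, not code from the file):

    import pandas as pd

    def anti_join(left, right):
        """Rows of `left` that do not appear in `right` (same columns)."""
        merged = pd.merge(left, right, how='outer', indicator=True)
        return merged[merged['_merge'] == 'left_only'].drop('_merge', axis=1)

    edges = pd.DataFrame({'src': [0, 1, 2], 'dst': [3, 4, 5]})
    taken = pd.DataFrame({'src': [1], 'dst': [4]})
    remaining = anti_join(edges, taken)      # keeps rows (0, 3) and (2, 5)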
val_edg[i]=val_edg[i].drop_duplicates() - test_edg[i]=test_edg[i].drop_duplicates() - - val_list[i][(ss, etype, dd)] = (torch.tensor(val_edg[i]['src'].tolist()), - torch.tensor(val_edg[i]['dst'].tolist())) - test_list[i][(ss, etype, dd)] = (torch.tensor(test_edg[i]['src'].tolist()), - torch.tensor(test_edg[i]['dst'].tolist())) - train_list[i][(ss, etype, dd)] = (torch.tensor(new_edges['src'].tolist()), - torch.tensor(new_edges['dst'].tolist())) - - return val_list, test_list, train_list, iteration - - -def create_folds(graph, mask, k, max_iterations=100): - """ - Create folds for cross-validation. - - Args: - graph (dgl.DGLHeteroGraph): Input graph. - mask (dict): Mask parameters. - k (int): Number of folds. - max_iterations (int, optional): Maximum number of iterations. Defaults to 100. - - Returns: - list: List of training graphs. - list: List of validation graphs. - list: List of test graphs. - - """ - train_list, val_list, test_list = [dict() for _ in range(k)], [dict() for _ in range(k)], [dict() for _ in range(k)] - - continue_loop = True - for ss, etype, dd in graph.canonical_etypes: - iteration = max_iterations + 1 - continue_outer_loop = True - val_etype_lis, test_etype_lis, train_etype_lis, iteration = random_fold_loop(graph, mask, k,etype,ss,dd, max_iterations) - for i in range(0,k): - #print(list(val_etype_lis[i].values())) - val_list[i][list(val_etype_lis[i].keys())[0]]=list(val_etype_lis[i].values())[0] - test_list[i][list(test_etype_lis[i].keys())[0]]=list(test_etype_lis[i].values())[0] - train_list[i][list(train_etype_lis[i].keys())[0]]=list(train_etype_lis[i].values())[0] - num_nodes_dict = {ntype: graph.num_nodes(ntype) for ntype in graph.ntypes} - train_list = [dgl.heterograph(train_list[i], num_nodes_dict=num_nodes_dict) for i in range(k)] - val_list = [dgl.heterograph(val_list[i], num_nodes_dict=num_nodes_dict) for i in range(k)] - test_list = [dgl.heterograph(test_list[i], num_nodes_dict=num_nodes_dict) for i in range(k)] - - return train_list, val_list, test_list -def check_folds(edges,fold_list_train,fold_list_val,fold_list_test): - """ - Check consistency of generated folds. - - Args: - edges: Edge data. - fold_list_train: List of training graphs. - fold_list_val: List of validation graphs. - fold_list_test: List of test graphs. - - Returns: - bool: True if folds are consistent, False otherwise. 
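create_folds above materializes each fold's edge dictionary as a full heterograph; passing num_nodes_dict keeps node IDs aligned across the train/val/test graphs of every fold. The construction pattern, with a toy split:

    import dgl
    import torch

    num_nodes = {'disease': 4, 'drug': 3}
    fold_edges = {('disease', 'dis_dru_the', 'drug'):
                  (torch.tensor([0, 2]), torch.tensor([1, 0]))}
    fold_g = dgl.heterograph(fold_edges, num_nodes_dict=num_nodes)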
- - """ - check=True - for t,v,ts in zip(fold_list_train,fold_list_val,fold_list_test): - for etype in edges.etypes: - src,dst=edges.edges(etype=etype) - tsrc,tdst=t.edges(etype=etype) - vsrc,vdst=v.edges(etype=etype) - tssrc,tsdst=ts.edges(etype=etype) - edg = pd.DataFrame({'src': src.squeeze(), 'dst': dst.squeeze()}) - tedg = pd.DataFrame({'src': tsrc.squeeze(), 'dst': tdst.squeeze()}) - vedg = pd.DataFrame({'src': vsrc.squeeze(), 'dst': vdst.squeeze()}) - tsedg = pd.DataFrame({'src': tssrc.squeeze(), 'dst': tsdst.squeeze()}) - comp=pd.concat([tedg,vedg,tsedg], ignore_index=True).drop_duplicates() - concated=pd.concat([tedg,vedg,tsedg]) - print(etype) - print(concated[concated.duplicated(subset=['src','dst'],keep=False)]) - if (not concated[concated.duplicated(subset=['src','dst'],keep=False)].empty): - dis, dru, pat, pro, ddi = loadNodes(True) - if(edges.to_canonical_etype(etype)[0] == 'disease' or edges.to_canonical_etype(etype)[2] == 'disease'): - if(edges.to_canonical_etype(etype)[2] == 'disease'): - dru=dis - print("entra") - if(edges.to_canonical_etype(etype)[0] == 'drug' or edges.to_canonical_etype(etype)[2] == 'drug'): - if(edges.to_canonical_etype(etype)[0] == 'drug'): - dis=dru - if(edges.to_canonical_etype(etype)[0] == 'pathway' or edges.to_canonical_etype(etype)[2] == 'pathway'): - if(edges.to_canonical_etype(etype)[0] == 'pathway'): - dis=pat - else: - dru=pat - if(edges.to_canonical_etype(etype)[0] == 'protein' or edges.to_canonical_etype(etype)[2] == 'protein'): - if(edges.to_canonical_etype(etype)[0] == 'protein'): - dis=pro - else: - dru=pro - print("sale") - if(edges.to_canonical_etype(etype)[0] == 'drug_drug_interaction' or edges.to_canonical_etype(etype)[1] == 'drug_drug_interaction'): - if(edges.to_canonical_etype(etype)[0] == 'drug_drug_interaction'): - dis=ddi - else: - dru=ddi - dis = dis.rename(lambda x: "node_id" if x != "name" else x, axis=1) - dru = dru.rename(lambda x: "node_id" if x != "name" else x, axis=1) - dis = dis.drop_duplicates(keep='last').reset_index(drop=True).rename_axis("id").reset_index() - dru = dru.drop_duplicates(keep='last').reset_index(drop=True).rename_axis("id").reset_index() - dis_dict = dis[["id","node_id"]].set_index("id").to_dict()["node_id"] - dru_dict = dru[["id","node_id"]].set_index("id").to_dict()["node_id"] - dup=concated[concated.duplicated(subset=['src','dst'],keep=False)] - dup['src']=dup.src.map(dis_dict) - dup['dst']=dup.dst.map(dru_dict) - dup.to_csv("links_duplicados_"+etype+".csv") - check=False - #print(comp) - #print(edg) - common=pd.merge(edg, comp, on=['src', 'dst'], how='left', indicator=True) - print((common['_merge'] == 'left_only').any()) - if((common['_merge'] == 'left_only').any()): - check=False - return check -class NewNegativeSampler(object): - - def __init__(self): - pass - def __call__(self,g,eids_dict,k=1): - neg_edges=dict() - for etype in g.etypes: - neg_edges[g.to_canonical_etype(etype)]=dgl.sampling.global_uniform_negative_sampling(g,g.num_edges(etype)*k,etype=etype,exclude_self_loops=False) - return neg_edges -def remove_repetitive(g): - new_edges={} - print(g) - for src,etype,dst in g.canonical_etypes: - if not etype.startswith("rev_"): - new_edges[(src,etype,dst)]=g.edges(etype=etype) - elif src!=dst: - new_edges[(src,etype,dst)]=g.edges(etype=etype) - return dgl.heterograph(new_edges,num_nodes_dict={ntype: g.num_nodes(ntype) for ntype in g.ntypes}) - -def get_weights(comp_g,g): - dic=dict() - for (srct,etype,dstt) in g.canonical_etypes: - src, dst = g.edges(etype=etype) - # Convert source and 
destination node IDs to the appropriate data type - src_ids = src.to(torch.int64) - dst_ids = dst.to(torch.int64) - # Get edge IDs in the compact graph - #print(src_ids) - #print(comp_g.edata["w"]) - edge_ids=None - edge_ids = comp_g.edge_ids(src_ids.tolist(), dst_ids.tolist(), etype=etype).flatten() - - # Extract edge weights based on edge IDs - edge_weights = torch.index_select(comp_g.edata["w"][g.to_canonical_etype(etype)], 0, edge_ids) - # Create a DataFrame to store source, destination, and edge weights - data = pd.DataFrame({"src": src, "dst": dst, "w": edge_weights}) - # Convert edge weights to a tensor and store in a dictionary - dic[etype] = torch.tensor(data["w"].tolist()) - print("outside loop") - return dic - -def print_topk(edges,pred,k=10): - dis, dru = loadNodes(False) - - dis = dis.rename(lambda x: "node_id" if x != "name" else x, axis=1) - dru = dru.rename(lambda x: "node_id" if x != "name" else x, axis=1) - - dis = dis.drop_duplicates(keep='last').reset_index(drop=True).rename_axis("id").reset_index() - dru = dru.drop_duplicates(keep='last').reset_index(drop=True).rename_axis("id").reset_index() - - dis_dict = dis[["id","node_id"]].set_index("id").to_dict()["node_id"] - dru_dict = dru[["id","node_id"]].set_index("id").to_dict()["node_id"] - dis_name = dis[["name","node_id"]].set_index("node_id").to_dict()["name"] - dru_name = dru[["name","node_id"]].set_index("node_id").to_dict()["name"] - - src,dst = edges - sorted_pred, indices = pred.squeeze().detach().cpu().sort(descending=True) - sorted_pred, indices = sorted_pred[:k] if k < len(pred) else sorted_pred , indices[:k] if k < len(pred) else indices - sorted_src,sorted_dst = src[indices],dst[indices] - output_dataframe = pd.DataFrame({"dis":sorted_src.squeeze().cpu(),"dru":sorted_dst.squeeze().cpu()}) - output_dataframe["dis"] = output_dataframe.dis.map(dis_dict) - output_dataframe["dru"] = output_dataframe.dru.map(dru_dict) - output_dataframe["dis_name"] = output_dataframe.dis.map(dis_name) - output_dataframe["dru_name"] = output_dataframe.dru.map(dru_name) - output_dataframe['pred'] = sorted_pred.squeeze().detach().cpu() - - print(output_dataframe) - -sweep_configuration_fold = { - "method": "random", - "name": "sweep_fold", - "metric": {"goal": "maximize", "name": "auroc"}, - "parameters": { - "batch_size": {"values": [16, 32, 64]}, - "epochs": {"values":[1000, 2000,2500]}, - "lr": {"values":[0.1,1e-2,1e-3,1e-4,1e-5]}, - "eps": {"values":[1e-10,1e-6,1e-2,1,1e2]}, - "dropout": {"values": [0.0,0.2,0.5,0.7,1.0]}, - "weight_decay": {"values":[1.0 ,0.1,0.01,0.001,0.0001]}, - "betas1": {"values":[0.9999,0.9,0.3,0.1]}, - "betas2": {"values":[0.9999,0.9,0.3,0.1]} - }, -} -def train_disease_graph(graph, final_bed, disease_graph,h_dim, test=False,epochs=100,dropout=0.2): - disease_s_gnn= nn.Sequential( - dglnn.SAGEConv(final_bed.shape[1],h_dim), - nn.ReLU(), - nn.Dropout(dropout), - dglnn.SAGEConv(h_dim,h_dim), - nn.Sigmoid() - ) - los=nn.CosineEmbeddingLoss() - if not test: - optimizer = torch.optim.AdamW(disease_s_gnn.parameters(), lr=learning_rate,weight_decay=weight_decay,eps=eps,betas=betas) - for e in epochs: - metrics,subgraph=create_distances_subgraph - emb=disease_s_gnn(subgraph,metrics.extend(final_bed.shape)+final_bed) - emb2=disease_s_gnn(disease_graph,disease_graph.ndata['feat'].extend(final_bed.shape)+final_bed) - loss=los(emb,emb2,torch.tensor(emb.shape[0])) - loss.backward() - optimizer.step() - torch.save(disease_s_gnn.state_dict(), "disease_gnn.pt") - else: - 
disease_s_gnn.load_state_dict(torch.load("disease_gnn.pt",weights_only=False)) - emb=disease_s_gnn(disease_graph,disease_graph.ndata['feat'].extend(final_bed.shape)+final_bed) - return emb -def main(): - """ - Main script for training and evaluating a Graph Neural Network using DGL. - - This script initializes the model, loads and preprocesses the dataset, creates folds for cross-validation, - trains the model, and evaluates its performance on the test set. - - Args: - None - - Returns: - None - """ - # Set the device to GPU if available, otherwise use CPU - print(f"Training with DGL built-in RGCN module") - runi = wandb.init(project='fold_sweep') - - lr = wandb.config.lr - eps = wandb.config.eps - epochs = wandb.config.epochs - dropout= wandb.config.dropout - wd=wandb.config.weight_decay - betas=(wandb.config.betas1,wandb.config.betas2) - # log metrics to wandb - - # load and preprocess dataset - k=10 - train_subgraph_node_embedder=True - #epochs=100 - transform=dgl.AddReverse(copy_edata=True,sym_new_etype=True) - # log metrics to wandb - - # load and preprocess dataset - full=True - data,kk,test_dis_dru,comp_g,w_dict = create_heterograph(True) - nu_dict=dict((key,data.num_nodes(key)) for key in data.ntypes) - #print(nu_dict) - #final_test_g=dgl.heterograph({("disease", "dis_dru_the", "drug"): test_dis_dru },num_nodes_dict=comp_nodes_num_dict) - #print(comp_g.edges(etype="dis_dru_the")) - #print(final_test_g.edges(etype="dis_dru_the")) - #final_test_g=final_test_g.cpu() - #final_test_g.edata['w']=torch.ones(final_test_g.num_edges("dis_dru_the")) - - data.edata['w']=get_weights(comp_g,data) - mask={'train': 0.5, 'val': 0.4, 'test': 0.1} - emb_dru=load_drug_emb(device,mask,h_dim=400) if not os.path.exists("embeddings/drug_emb.tsv") else pd.read_csv("embeddings/drug_emb.tsv",sep='\t') - emb_dru=emb_dru.replace(r'\n',' ', regex=True) - - emb_dru['name']=emb_dru.name.map(kk[0]) - import ast - for ia,row in emb_dru.iterrows(): - emb_dru.at[ia,'embedding']=np.asarray(np.matrix(row['embedding']).reshape(1,400)).squeeze() - emb_dru=emb_dru.sort_values(by=['name']) - #print(emb_dru['embedding']) - new_form=None - - emb_dru=torch.tensor(np.vstack(emb_dru['embedding'])) - if(full): - emb_prot=load_protein_emb(kk[1]) - data.ndata['feat']['protein']=emb_prot - #data.ndata['feat']['drug']=emb_dru - #data = FB15k237Dataset(reverse=False) - #g = data[0] - in_dim=400 - data=transform(data) - #print(data) - comp_g=transform(comp_g) - test_nids={ntype: torch.tensor([[1.0]*in_dim for u in range(data.num_nodes(ntype))]) for ntype in data.ntypes} - #if test_nids == data.ndata['h']: - # print("gottem") - g=data - #print(g) - num_nodes = g.num_nodes() - #num_rels = data.num_rels - num_rels=len(g.etypes) - #train_g = get_subset_g(g, g.edata["train_mask"], num_rels) - #sampler = dgl.dataloading.NeighborSampler(sampling) - #sampler = NegativeSamplerHet(g, k=1, neg_share=False) - sampler=NewNegativeSampler() - print("Full graph: {:.4f}".format(comp_g.num_edges("dis_dru_the"))) - print("Train graph: {:.4f}".format(g.num_edges("dis_dru_the"))) - #print("Test graph: {:.4f}".format(final_test_g.num_edges("dis_dru_the"))) - #if(full): - # comp_g.ndata['h']['protein'],comp_rand_graph.ndata['h']['protein']=emb_prot,emb_prot - #comp_g.ndata['h']['drug'],comp_rand_graph.ndata['h']['drug']=emb_dru,emb_dru - #print(neg_graph) - mask={'train': 0.5, 'val': 0.4, 'test': 0.1} - - #data.ndata['h']['drug'],test_nids['drug']=emb_dru,emb_dru - prc_list=list() - aur_list=list() - #test_mask = g.edata["test_mask"] - best_aur=0 - 
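Both NewNegativeSampler and the training loop lean on dgl.sampling.global_uniform_negative_sampling, which draws (src, dst) pairs absent from a given edge type; a minimal usage sketch with a toy graph:

    import dgl
    import torch

    g = dgl.heterograph({('disease', 'dis_dru_the', 'drug'):
                         (torch.tensor([0, 1]), torch.tensor([0, 1]))},
                        num_nodes_dict={'disease': 5, 'drug': 5})
    neg_src, neg_dst = dgl.sampling.global_uniform_negative_sampling(
        g, num_samples=2, etype='dis_dru_the', exclude_self_loops=False)
    neg_g = dgl.heterograph({('disease', 'dis_dru_the', 'drug'): (neg_src, neg_dst)},
                            num_nodes_dict={'disease': 5, 'drug': 5})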
best_prc=0 - best_model=None - eids_dict = {etype: torch.tensor(range(0,len(g.edges(etype=etype)[0]))) for etype in g.etypes} - neg_samples = sampler(g, eids_dict) - num_nodes={ntype: g.num_nodes(ntype) for ntype in g.ntypes} - neg_graph=dgl.heterograph(neg_samples,num_nodes_dict=num_nodes) - #print(neg_graph) - mask={'train': 0.8, 'val': 0.15, 'test': 0.05} - pos_train_list,pos_val_list,pos_test_list=create_folds(g,mask,k) - print(check_folds(g,pos_train_list,pos_val_list,pos_test_list)) - previous_embeddings=torch.ones(g.num_nodes("diseases"),200) - final_p_embedding=torch.ones(g.num_nodes("diseases"),200) - for u in tqdm.tqdm(range(epochs),total=epochs): - #pos_g_dict=disjoint_split_hetero(g) - #pos_train_list,pos_val_list,pos_test_list= pos_g_dict['train'],pos_g_dict['val'],pos_g_dict['test'] - #pos_train_list,pos_val_list,pos_test_list=get_subsets(g,mask) - matrix, metrics2 = get_incidence(pos_train_list[u], ntype, applySubcosts) - disease_graph= dgl.graph(torch.nonzero(matrix, as_tuple=True),num_nodes=pos_train_list[u].num_nodes("disease")) - final_bed = subgraph_node_embedder(pos_train_list[u], disease_graph, "disease", 200,previous_embeddings, epochs=200, epochs_sub=100, - epochs_seq=100,train_se=True, train_so=True, steps=2) - disease_embs = train_disease_graph(pos_train_list[u], final_bed, disease_graph, test=False,epochs=100) - - g.ndata['disease'] = final_bed - disease_graph.ndata["disease"] = final_bed - print(g.edata['w']) - pos_train_list[u].edata['w']=get_weights(g,pos_train_list[u]) - pos_val_list[u].edata['w']=get_weights(g,pos_val_list[u]) - pos_test_list[u].edata['w']=get_weights(g,pos_test_list[u]) - #if(full): - # pos_train_list[ji].ndata['h']['protein'],pos_val_list[ji].ndata['h']['protein'],pos_test_list[ji].ndata['h']['protein']=emb_prot,emb_prot,emb_prot - #pos_train_list[ji].ndata['h']['drug'],pos_val_list[ji].ndata['h']['drug'],pos_test_list[ji].ndata['h']['drug']=emb_dru,emb_dru,emb_dru - print("Train subgraph: {:.4f}".format(pos_train_list[u].num_edges("dis_dru_the"))) - print("Training test subgraph: {:.4f}".format(pos_val_list[u].num_edges("dis_dru_the"))) - print("Validation subgraph: {:.4f}".format(pos_test_list[u].num_edges("dis_dru_the"))) - #print(check_folds(g,pos_train_list,pos_val_list,pos_test_list)) - #for ji in range(len(pos_train_list)): - # if(full): - # neg_train_list[ji].ndata['h']['protein'],neg_val_list[ji].ndata['h']['protein'],neg_test_list[ji].ndata['h']['protein']=emb_prot,emb_prot,emb_prot - # neg_train_list[ji].ndata['h']['drug'],neg_val_list[ji].ndata['h']['drug'],neg_test_list[ji].ndata['h']['drug']=emb_dru,emb_dru,emb_dru - #test_g = get_subset_g(g,g.edata["train_mask"], num_rels, bidirected=True) - #test_g.edata["norm"] = dgl.norm_by_dst(test_g).unsqueeze(-1) - #test_nids = {ntype: torch.tensor(np.arange(0,g.num_nodes(ntype))).to(device) for ntype in g.ntypes} - if(full): - test_nids['protein']=emb_prot - test_nids['drug']=emb_dru - #test_nids=node_features - test_mask = g.edata["test_mask"] - #sampler=dgl.dataloading.NeighborSampler([-1 for etype in pos_train_list[u]],prefetch_edge_feats={etype: ['w'] for etype in pos_train_list[u].etypes}) - #dataloader=dgl.dataloading.DistEdgeDataloader(pos_train_list[u],{etype: pos_train_list[u].edges(etype=etype) for etype in pos_train_list[u].etypes}, sampler) - - # Prepare data for metric computation - triplets = {etype: torch.stack([pos_val_list[u].edges(etype=etype)[0],pos_val_list[u].edges(etype=etype)[1]],dim=1) for src,etype,dst in pos_test_list[u].canonical_etypes} - test_mask={etype: 
list(range(0,len(triplets[etype]))) for src,etype,dst in pos_test_list[u].canonical_etypes} - g.ndata['feat']=test_nids - # create RGCN model - print(len(g.ndata['feat']['protein'][0])) - best_mod,prc_l,aur_l=train_fold_with_samples(pos_val_list[u],disease_embs,g,model,pos_test_list[u],comp_rand_g,test_nids,test_mask,triplets,device,h_dim=30,w=g.edata["w"],epochs=epochs,learning_rate=lr,weight_decay=wd,eps=eps) - prc_list.append(mean(prc_l)) - aur_list.append(mean(aur_l)) - previous_embeddings=best_mod["disease"] - if(mean(aur_list)>best_aur and mean(prc_list) > best_prc): - best_model=best_mod - best_prc=mean(prc_list) - best_aur=mean(aur_list) - runi.log( - { - "auroc": mean(aur_l), - "prc": mean(prc_l), - "mean_prc": mean(prc_list), - "mean_auroc": mean(aur_list) - - } - ) - - matrix, metrics2 = get_incidence(graph, ntype, applySubcosts) - disease_graph = dgl.graph(torch.nonzero(matrix, as_tuple=True),num_nodes=graph.num_nodes("disease")) - final_bed = subgraph_node_embedder(g, disease_graph, "disease", 200,final_p_embedding, steps=2) - disease_embs=train_disease_graph(graph,final_bed,disease_graph,30,test=True) - g.ndata['feat']['disease'] = final_bed - model = LinkPredict(g, g.num_nodes, in_dim=200, h_dim=30, dropout=dropout, device=device).to(device) - best_mod, prc_l, aur_l = train_fold_with_samples(g, disease_embs, g, model, pos_test_list[u], - comp_rand_g, - test_nids, test_mask, triplets, device, h_dim=30, w=g.edata["w"], - epochs=epochs, learning_rate=lr, weight_decay=wd, eps=eps) - runi.log( - { - "final_auroc": aur_l, - "final_prc": prc_l - - } - ) - runi.finish() - disease_embedding=best_mod["disease"]+final_bed+disease_embs/3 - # testing - print("Final Testing...") - torch.save(disease_embedding, "tensor.pt") - -if __name__=='__main__': - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - sweep_id = wandb.sweep(sweep=sweep_configuration_fold, project="disease_sweep") - wandb.agent(sweep_id, function=main, count=1) diff --git a/disease_src/disease-base-simple.py b/disease_src/embedding_construction.py similarity index 100% rename from disease_src/disease-base-simple.py rename to disease_src/embedding_construction.py diff --git a/disease_src/disease_test_ODP.py b/disease_src/test_repoDB.py similarity index 100% rename from disease_src/disease_test_ODP.py rename to disease_src/test_repoDB.py diff --git a/disease_src/pruebaLinkHetAdrian_ODP.py b/disease_src/test_repoDB_induction.py similarity index 100% rename from disease_src/pruebaLinkHetAdrian_ODP.py rename to disease_src/test_repoDB_induction.py