import gc
import math
import os
import random
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from itertools import combinations
from statistics import mean

import dgl
import dgl.function as fn
import dgl.nn.pytorch as dglnn
import matplotlib.pyplot as plt
import mysql.connector
import numpy as np
import pandas as pd
import pysmiles.read_smiles as read_smiles
import requests
import scipy
import scipy.cluster.hierarchy
import scipy.sparse
import scipy.spatial.distance
import scipy.stats as st
import seaborn as sns
import sklearn.metrics as metrics
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric as pyg
import tqdm
import wandb
from Bio.Phylo.TreeConstruction import DistanceMatrix, DistanceTreeConstructor
from cuml.metrics import pairwise_distances
from deepsnap.hetero_gnn import forward_op
from dgl.data.knowledge_graph import FB15k237Dataset
from dgl.dataloading import GraphDataLoader
from dgl.nn.pytorch import RelGraphConv
from pycirclize import Circos
from pycirclize.utils import ColorCycler
from sklearn.cluster import DBSCAN, HDBSCAN, SpectralClustering
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.metrics import (adjusted_rand_score, normalized_mutual_info_score,
                             roc_auc_score, silhouette_score)
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler
from umap import UMAP

import DRAGON.heterograph_construction as heterograph_construction
from DRAGON.get_embedding_prot import export_protein_emb
from models import GS, MLPPredS, HashGNN

# Do not commit a real credential; read the UMLS API key from the environment.
UMLS_API_KEY = os.environ.get("UMLS_API_KEY", "YOUR_UMLS_API_KEY")


def get_icd10cm_from_cui(cui, visited=None):
    """Fetch the ICD-10-CM codes for a given UMLS CUI using the UMLS API."""
    if visited is None:
        visited = set()
    if cui in visited:
        return None
    visited.add(cui)
    url = f"https://uts-ws.nlm.nih.gov/rest/content/current/CUI/{cui}/atoms?apiKey={UMLS_API_KEY}&sabs=ICD10CM"
    print(f"Requesting URL: {url}")
    response = requests.get(url)
    if response.status_code != 200:
        print(f"Error: {response.status_code} - {response.text}")
        return None
    data = response.json()
    results = data.get("result", [])
    # print("Raw JSON Response:", results[0]["code"].split("/")[-1])
    # icd10_codes = [item["code"].split("/")[-1] for item in results if "code" in item.keys()]
    icd10_codes = [(item["name"], item["code"].split("/")[-1]) for item in results if "name" in item.keys()]
    if icd10_codes:
        return icd10_codes
    print("No direct ICD-10-CM match found. Trying to find parent...")
    # def get_mth_name(cui)
    codes, ba = get_icd10cm_parent(cui)
    if not ba and codes is not None:
        # The MTH fallback returns plain parent CUIs: recurse on each of them until
        # one resolves to an ICD-10-CM code; `visited` prevents cycles.
        code = codes
        codes = None
        for cod in code:
            codes = get_icd10cm_from_cui(cod, visited)
            if codes is not None:
                break
    return codes
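# Illustrative usage (added; not part of the original pipeline -- the CUI and the
# returned (name, ICD-10-CM code) tuples below are hypothetical placeholders):
#   >>> get_icd10cm_from_cui("C0011849")
#   [("Diabetes mellitus", "E11"), ...]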
def get_icd10cm_parent(cui):
    """Find the parent ICD-10-CM code if no direct mapping exists.

    Returns a tuple (parents, par_ex): par_ex is True when the parents come from
    ICD10CM as (name, code) tuples, and False when they are plain CUIs from MTH.
    """
    url = f"https://uts-ws.nlm.nih.gov/rest/content/current/CUI/{cui}/relations?apiKey={UMLS_API_KEY}&sabs=ICD10CM"
    print(f"Requesting URL: {url}")
    response = requests.get(url)
    if response.status_code != 200:
        print(f"Error: {response.status_code} - {response.text}")
        return (None, False)
    data = response.json()
    results = data.get("result", [])
    if results:
        print("Raw JSON Response:", results[0]["relatedId"].split("/")[-1])
    # parents = [item["relatedId"].split("/")[-1] for item in results if item["relationLabel"] in ["PAR", "RB"]]
    parents = [(item["relatedIdName"], item["relatedId"].split("/")[-1])
               for item in results if item["relationLabel"] in ["PAR", "RB"]]
    par_ex = True
    if not parents:
        url = f"https://uts-ws.nlm.nih.gov/rest/content/current/CUI/{cui}/relations?apiKey={UMLS_API_KEY}&sabs=MTH"
        print(f"Requesting URL: {url}")
        response = requests.get(url)
        if response.status_code != 200:
            print(f"Error: {response.status_code} - {response.text}")
            return (None, False)
        data = response.json()
        results = data.get("result", [])
        if results:
            print("Raw JSON Response:", results[0]["relatedId"].split("/")[-1])
        parents = [item["relatedId"].split("/")[-1] for item in results if item["relationLabel"] in ["PAR", "RO"]]
        par_ex = False
    return (parents, par_ex) if parents else (None, False)


def batch_process_cuis(cui_list, max_threads=10):
    """Map multiple CUIs to ICD-10-CM codes in parallel."""
    results = {}
    with ThreadPoolExecutor(max_workers=max_threads) as executor:
        future_to_cui = {executor.submit(get_icd10cm_from_cui, cui): cui for cui in cui_list}
        for future in future_to_cui:
            cui = future_to_cui[future]
            try:
                result = future.result()
                results.update({cui: result})
            except Exception as e:
                print(f"Error processing CUI {cui}: {e}")
                results[cui] = None
    return results


def get_none_n(l):
    """Fetch the preferred English atom name for a CUI, falling back to the CUI itself."""
    url = f"https://uts-ws.nlm.nih.gov/rest/content/current/CUI/{l}/atoms?apiKey={UMLS_API_KEY}&language=ENG"
    print(f"Requesting URL: {url}")
    response = requests.get(url)
    if response.status_code != 200:
        print(f"Error: {response.status_code} - {response.text}")
        return None
    data = response.json()
    results = data.get("result", [])
    icd10_codes = [item["name"] for item in results if "name" in item.keys()]
    if icd10_codes:
        return icd10_codes[0]
    return l


def get_none_name(cui_list, max_threads=10):
    """Resolve names for multiple CUIs in parallel."""
    results = {}
    with ThreadPoolExecutor(max_workers=max_threads) as executor:
        future_to_cui = {executor.submit(get_none_n, cui): cui for cui in cui_list}
        for future in future_to_cui:
            cui = future_to_cui[future]
            try:
                result = future.result()
                results.update({cui: result})
            except Exception as e:
                print(f"Error processing CUI {cui}: {e}")
                results[cui] = None
    return results
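# Sketch of how these batch helpers are combined further down in plot_distances
# (the CUIs here are hypothetical placeholders):
#   cui_to_icd = batch_process_cuis(["C0011849", "C0020538"], max_threads=10)
#   fallback_names = get_none_name([c for c, v in cui_to_icd.items() if v is None])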
def plot_hebundling(data, links, dist, codes, extension=""):
    """Plot a Circos-style figure whose layout comes from a UPGMA tree of disease distances."""
    # chr_bed_file, cytoband_file, chr_links = load_eukaryote_example_dataset("mm10")
    # Initialize Circos from BED chromosomes
    # circos = Circos({d: 100 for d in data}, space=1)
    # circos.text("Mus musculus\n(mm10)", deg=315, r=150, size=12)
    # Add cytoband tracks from cytoband file
    # circos.add_cytoband_tracks((95, 100), cytoband_file)
    # Create chromosome color mapping
    links["values"] = dist
    data2 = links[["diseaseA", "diseaseB", "values"]]
    new_data2 = []
    for _, row in data2.iterrows():
        new_data2.append((row['diseaseA'], row['diseaseB'], row['values']))
        new_data2.append((row['diseaseB'], row['diseaseA'], row['values']))
    new_data2 = pd.DataFrame(new_data2, columns=["diseaseA", "diseaseB", "values"])
    data2 = new_data2.drop_duplicates(ignore_index=True)
    df_pivot = data2.pivot_table(columns="diseaseA", index="diseaseB", values="values")
    df_pivot = df_pivot.loc[:, ~df_pivot.columns.duplicated()].copy()
    df_pivot = df_pivot.drop_duplicates()
    # df_pivot = df_pivot.fillna(-1)
    distance_matrix = df_pivot.to_numpy()
    pivot_df = df_pivot.copy()
    np.fill_diagonal(pivot_df.values, 0)
    # pivot_df.values = np.nan_to_num(pivot_df.values)
    print(pivot_df.values[~np.isnan(pivot_df.values)])
    print(pivot_df)
    pivot_df = pivot_df.fillna(-1)
    lower_triangle = [[pivot_df.values.tolist()[i][j] for j in range(i + 1)]
                      for i in range(len(pivot_df.values.tolist()))]
    matrix = DistanceMatrix(pivot_df.index.to_list(), lower_triangle)
    tree = DistanceTreeConstructor()
    tree = tree.upgma(matrix)
    circos, tv = Circos.initialize_from_tree(tree)
    new_color_list = []
    for i in pivot_df.index.to_list():
        if new_color_list != []:
            not_found = True
            for col in new_color_list:
                if codes[i][0] == codes[col[0]][0]:
                    col.append(i)
                    not_found = False
            if not_found:
                new_color_list.append([i])
        else:
            new_color_list.append([i])
    ColorCycler.set_cmap("hsv")
    chr_names = [s.name for s in circos.sectors]
    colors = ColorCycler.get_color_list(len(new_color_list))
    for i in range(len(colors)):
        # print(len(new_color_list[i]))
        # print(colors[i])
        tv.set_node_line_props(new_color_list[i], color=colors[i])
        for u in new_color_list[i]:
            print(u)
            tv.set_node_label_props(u, color=colors[i])
    # chr_name2color = {name: color for name, color in zip(chr_names, colors)}
    # print(circos.sectors)
    # Plot chromosome name & xticks
    """
    for sector in circos.sectors:
        sector.text(sector.name, r=120, size=10, color=chr_name2color[sector.name])
        label_position = "outside"
    # Plot chromosome link
    position = 0
    max_strength = {d: 0 for d in data}
    for link, dst in zip(links.iterrows(), dist):
        print(dst)
        max_strength[link[1]["diseaseA"]] += dst.item()
        max_strength[link[1]["diseaseB"]] += dst.item()
    regionX = {d: 0 for d in data}
    for link, dst in zip(links.iterrows(), dist):
        # print(link)
        r1rx = dst / max_strength[link[1]["diseaseA"]] * 100
        r2rx = dst / max_strength[link[1]["diseaseA"]] * 100
        region1 = (link[1]["diseaseA"], regionX[link[1]["diseaseA"]], regionX[link[1]["diseaseA"]] + r1rx)
        region2 = (link[1]["diseaseB"], regionX[link[1]["diseaseB"]], regionX[link[1]["diseaseB"]] + r2rx)
        color = chr_name2color[link[1]["diseaseA"]]
        if link[1]["diseaseA"] != link[1]["diseaseB"]:
            print(region1)
            print(region2)
            circos.link(region1, region2, color=color)
    """
    fig = circos.savefig("circos_" + extension + ".png")


def heat_map(data, dist, extension=""):
    """Plot a symmetric heat map of pairwise disease distances."""
    data["values"] = dist
    data2 = data[["diseaseA", "diseaseB", "values"]]
    new_data2 = []
    for _, row in data2.iterrows():
        new_data2.append((row['diseaseA'], row['diseaseB'], row['values']))
        new_data2.append((row['diseaseB'], row['diseaseA'], row['values']))
    new_data2 = pd.DataFrame(new_data2, columns=["diseaseA", "diseaseB", "values"])
    data2 = new_data2.drop_duplicates(ignore_index=True)
    df_pivot = data2.pivot_table(columns="diseaseA", index="diseaseB", values="values")
    df_pivot = df_pivot.loc[:, ~df_pivot.columns.duplicated()].copy()
    df_pivot = df_pivot.drop_duplicates()
    # df_pivot = df_pivot.fillna(-1)
    distance_matrix = df_pivot.to_numpy()
    pivot_df = df_pivot.copy()
    np.fill_diagonal(pivot_df.values, 0)
    # pivot_df.values = np.nan_to_num(pivot_df.values)
    print(pivot_df.values[~np.isnan(pivot_df.values)])
    print(pivot_df)
    pivot_df = pivot_df.fillna(-1)
    # df_pivot = data2.pivot("diseaseA", "diseaseB", "values")
    sns.heatmap(pivot_df, annot=True, cmap="crest")
    plt.savefig("heatmap_" + extension + ".png")
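# Note (added for clarity): plot_hebundling and heat_map both expect 'diseaseA' /
# 'diseaseB' columns plus an aligned 1-D array of pairwise distances; `codes` maps
# each disease name to its ICD-10-CM code string, and node colours are grouped by
# the code's first character (the ICD-10 chapter letter).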
class DiseaseHierarchy:
    def __init__(self, api_key, sources):
        """
        Initializes the class with a UMLS API key and a list of sources (SABs).

        :param api_key: UMLS API Key for authentication
        :param sources: List of source vocabularies (e.g., ["SNOMEDCT_US", "ICD10CM"])
        """
        self.api_key = api_key
        self.base_url = "https://uts-ws.nlm.nih.gov"
        self.sources = set(sources)  # Store as a set for quick lookup
        self.ticket = self.get_auth_ticket()

    def get_auth_ticket(self):
        """
        Retrieves a service ticket for API authentication.
        """
        auth_url = "https://utslogin.nlm.nih.gov/cas/v1/api-key"
        response = requests.post(auth_url, data={"apikey": self.api_key})
        if response.status_code == 201:
            return response.headers["location"]
        else:
            raise ValueError(f"Failed to authenticate: {response.text}")

    def get_hierarchies(self, query):
        """
        Queries UMLS for hierarchy information based on a CUI or term.

        :param query: Concept Unique Identifier (CUI) or name
        :return: Dictionary of retrieved hierarchy information
        """
        search_url = f"{self.base_url}/rest/search/current"
        params = {
            "string": query,
            "apiKey": self.api_key,
            "ticket": self.ticket,
            "sabs": ",".join(self.sources),  # Filter by specified sources
            "returnIdType": "code"
        }
        response = requests.get(search_url, params=params)
        if response.status_code != 200:
            raise ValueError(f"Failed to query UMLS: {response.text}")
        results = response.json().get("result", [])
        filtered_hierarchies = {}
        print("Raw JSON Response:", results)
        for item in results:
            cui = item.get("ui")
            if not cui or cui == "NONE":
                continue  # Skip invalid results
            source = item.get("rootSource")
            if source not in self.sources:
                continue  # Ensure the result comes from an allowed source
            hierarchy_info = self.get_hierarchy_for_cui(cui, source)
            if hierarchy_info:
                filtered_hierarchies[cui] = hierarchy_info
        return filtered_hierarchies

    def get_hierarchy_for_cui(self, cui, source):
        """
        Retrieves hierarchy information for a given CUI from a specific source.

        :param cui: Concept Unique Identifier (CUI)
        :param source: Source vocabulary (e.g., "SNOMEDCT_US")
        :return: Dictionary with hierarchy details
        """
        hierarchy_url = f"{self.base_url}/rest/content/current/source/{source}/CUI/{cui}/relations"
        params = {"apiKey": self.api_key, "ticket": self.ticket}
        response = requests.get(hierarchy_url, params=params)
        if response.status_code != 200:
            return None
        relations = response.json().get("result", [])
        hierarchy_data = []
        for relation in relations:
            relation_type = relation.get("relationLabel", "")
            related_id = relation.get("relatedId", "")
            related_name = relation.get("relatedName", "")
            hierarchy_data.append({
                "relation": relation_type,
                "related_id": related_id,
                "related_name": related_name,
                "source": source
            })
        return hierarchy_data if hierarchy_data else None
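# Example usage (added; hypothetical query, requires a valid UMLS API key):
#   hierarchy = DiseaseHierarchy(UMLS_API_KEY, ["SNOMEDCT_US", "ICD10CM"])
#   info = hierarchy.get_hierarchies("diabetes mellitus")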
def loadNodes():
    """
    Load disease nodes from file.

    Returns:
        pd.DataFrame: Dataframe containing disease nodes.
    """
    dis = pd.read_csv('data/nodes/dis.tsv', sep='\t')
    return dis


def load_data():
    """
    Load link data for building training/testing graphs.

    Returns:
        pd.DataFrame: Dataframe containing disease-drug (therapy) links.
        pd.DataFrame: Dataframe containing disease-symptom links.
        pd.DataFrame: Dataframe containing disease-patient links.
        pd.DataFrame: Dataframe containing disease-protein links.
        pd.DataFrame: Dataframe containing drug-symptom (indication) links.
        pd.DataFrame: Dataframe containing drug-symptom (side effect) links.
        pd.DataFrame: Dataframe containing DDI-phenotype links.
    """
    dis_dru_the = pd.read_csv('data/links/dis_dru_the.tsv', sep='\t')
    dis_sym = pd.read_csv('data/links/dis_sym.tsv', sep='\t')
    dis_pat = pd.read_csv('data/links/dis_pat.tsv', sep='\t')
    dis_pro = pd.read_csv('data/links/dis_pro.tsv', sep='\t')
    dru_sym_ind = pd.read_csv('data/links/dru_sym_ind.tsv', sep='\t')
    dru_sym_sef = pd.read_csv('data/links/dru_sym_sef.tsv', sep='\t')
    ddi_phe = pd.read_csv('data/links/ddi_phe.tsv', sep='\t')
    return dis_dru_the, dis_sym, dis_pat, dis_pro, dru_sym_ind, dru_sym_sef, ddi_phe


def jaccard_similarity(set1, set2):
    intersection = len(set1.intersection(set2))
    union = len(set1.union(set2))
    return intersection / union if union != 0 else 0


def calculate_mean_jaccard(df, group_column, value_column):
    # Group by the specified column and convert values to sets
    grouped = df.groupby(group_column)[value_column].apply(set).reset_index()
    # Create a DataFrame to store the similarity results
    similarity_results = pd.DataFrame(columns=['dis1', 'dis2', 'jaccard_similarity'])
    # Calculate Jaccard similarity for each pair of groups
    for (group1, set1), (group2, set2) in combinations(grouped[[group_column, value_column]].values, 2):
        similarity = jaccard_similarity(set1, set2)
        similarity_results = similarity_results._append({
            'dis1': group1,
            'dis2': group2,
            'jaccard_similarity': similarity
        }, ignore_index=True)
    # Return the pairwise Jaccard similarities
    return similarity_results


def load_naive_method(identifiers):
    """Build a sparse disease-disease matrix of mean Jaccard similarities over shared neighbours."""
    di = loadNodes()
    dis_dru_the, dis_sym, dis_pat, dis_pro, dru_sym_ind, dru_sym_sef, ddi_phe = load_data()
    di = di.rename(lambda x: "node_id" if x != "name" else x, axis=1).drop_duplicates().reset_index()
    di = di.drop_duplicates(keep='last').reset_index(drop=True).rename_axis("id").reset_index()
    di_dict = di[["id", "node_id"]].set_index("node_id").to_dict()["id"]
    dis_dru_the['dis'] = dis_dru_the.dis.map(di_dict)
    dis_dru_the = dis_dru_the[dis_dru_the['dis'].isin(identifiers)]
    dis_sym['dis'] = dis_sym.dis.map(di_dict)
    dis_sym = dis_sym[dis_sym['dis'].isin(identifiers)]
    dis_sym['sym'] = dis_sym.sym.map(di_dict)
    sym_dis = dis_sym[dis_sym['sym'].isin(identifiers)]
    dis_pro['dis'] = dis_pro.dis.map(di_dict)
    dis_pro = dis_pro[dis_pro['dis'].isin(identifiers)]
    dis_pat['dis'] = dis_pat.dis.map(di_dict)
    dis_pat = dis_pat[dis_pat['dis'].isin(identifiers)]
    dru_sym_ind['sym'] = dru_sym_ind.sym.map(di_dict)
    dru_sym_ind = dru_sym_ind[dru_sym_ind['sym'].isin(identifiers)]
    dru_sym_sef['sym'] = dru_sym_sef.sym.map(di_dict)
    dru_sym_sef = dru_sym_sef[dru_sym_sef['sym'].isin(identifiers)]
    ddi_phe['dis'] = ddi_phe.phe.map(di_dict)
    ddi_phe = ddi_phe[ddi_phe['dis'].isin(identifiers)]
    dis_dru_jaccard = calculate_mean_jaccard(dis_dru_the, "dis", "dru")
    dis_sym_jaccard = calculate_mean_jaccard(dis_sym, "dis", "sym")
    sym_dis_jaccard = calculate_mean_jaccard(sym_dis, "sym", "dis")
    dis_pro_jaccard = calculate_mean_jaccard(dis_pro, "dis", "pro")
    dis_pat_jaccard = calculate_mean_jaccard(dis_pat, "dis", "pat")
    dru_sym_ind_jaccard = calculate_mean_jaccard(dru_sym_ind, "sym", "dru")
    dru_sym_sef_jaccard = calculate_mean_jaccard(dru_sym_sef, "sym", "dru")
    ddi_phe_jaccard = calculate_mean_jaccard(ddi_phe, "dis", "ddi")
    df = pd.concat([dis_dru_jaccard, dis_sym_jaccard, sym_dis_jaccard, sym_dis_jaccard,
                    dis_pat_jaccard, dis_pro_jaccard, dru_sym_ind_jaccard,
                    dru_sym_sef_jaccard, ddi_phe_jaccard], ignore_index=True)
    df = df.groupby(['dis1', 'dis2'], as_index=False)['jaccard_similarity'].mean().reset_index()
    print(max(df['dis1']))
    print(len(df['jaccard_similarity']))
    print(max(df['dis2'].astype(np.int32)))
    print(max(identifiers))
    sp_matr = scipy.sparse.coo_matrix(
        (df.jaccard_similarity,
         (np.searchsorted(identifiers, df['dis1'].to_numpy()),
          np.searchsorted(identifiers, df['dis2'].to_numpy()))),
        shape=(len(identifiers), len(identifiers)))
    print(sp_matr.shape)
    return sp_matr
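# Worked example of the Jaccard similarity used above (added for clarity):
#   jaccard_similarity({"a", "b", "c"}, {"b", "c", "d"})
#   -> |{"b", "c"}| / |{"a", "b", "c", "d"}| = 2 / 4 = 0.5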
def plot_clusters(labels, data, path, dissimilar="no precomputed", UMA=False):
    # print("NaN values:", np.isnan(data.detach().cpu()).sum())
    # print("Infinite values:", np.isinf(data.detach().cpu()).sum())
    if dissimilar == "precomputed":
        umap = UMAP(metric="cosine", n_neighbors=30)
        # embedding = TSNE(n_components=2, metric=dissimilar, method="barnes_hut", init="random", perplexity=3.0, random_state=42)
        X_transformed = umap.fit_transform(data)
    else:
        X_transformed = data
    # Plot clusters
    plt.figure(figsize=(8, 6))
    unique_labels = set(labels)
    colors = plt.cm.get_cmap("tab10", len(unique_labels))
    for label in unique_labels:
        mask = (labels == label)
        plt.scatter(X_transformed[mask, 0], X_transformed[mask, 1],
                    label=f"Cluster {label}" if label != -1 else "Noise", alpha=0.7)
    plt.legend()
    plt.title("DBSCAN Clustering of CUIs")
    plt.savefig(path)


def load_diseases_dict():
    dis = pd.read_csv('data/nodes/dis.tsv', sep='\t')
    dis = dis.rename(lambda x: "node_id" if x != "name" else x, axis=1)
    dis = dis.drop_duplicates(keep='last').reset_index(drop=True).rename_axis("id").reset_index()
    disease_map = dis[["id", "node_id"]].set_index("node_id").to_dict()["id"]
    rev_disease_map = dis[["id", "node_id"]].set_index("id").to_dict()["node_id"]
    return disease_map, rev_disease_map


def mixed_info(true_clus, ref_clus):
    dif_clus = np.unique(true_clus)
    non_revisitable = set()
    punctuation = 0
    sizes = []
    banned_clusters = []
    # point_to_revise = [u for u in range(len(true_clus))]
    for ind, clus in enumerate(dif_clus):
        full_index = np.where(true_clus == clus)[0]
        index = np.array(list(set(full_index) - non_revisitable))
        if len(index) == 0:
            continue
        inx, count = np.unique(ref_clus[index], return_counts=True)
        rin = [i for i, x in enumerate(inx) if x not in banned_clusters]
        inx = inx[rin]
        count = count[rin]
        if len(count) == 0:
            punctuation += 0
            continue
        nx_loss = (len(true_clus) - count.max() * (1 - (len(dif_clus) / len(true_clus)))
                   if count.max() >= (len(true_clus) / len(dif_clus))
                   else count.max() * len(dif_clus) / len(true_clus))
        punctuation += count.max() * np.exp(-(count.max() - (len(true_clus) / len(dif_clus)) ** 2
                                              * -np.log1p(nx_loss) * len(count)))
        # sizes.append(exp(-(s-(len(true_clus)/len(dif_clus)))^2/(2*()^2)))
        # print(inx)
        banned_clusters.append(inx[np.argmax(count)])
        non_revisitable.update(np.where(index == inx[np.argmax(count)])[0])
    return punctuation / (len(true_clus))
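# Mode conventions shared by main() and plot_distances() (summary added for clarity):
#   -1 -> raw "Closest distance" proximities, 0 -> TransE, 1 -> TransR,
#    2 -> non-linear TransR, 3 -> Node2Vec, 4 -> GAE, 5 -> HashGNN,
#    6 -> basic model, 7 -> naive Jaccard baseline
#   (in plot_distances, mode 7 instead loads the edge-based model embeddings).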
def main(mode, reduction_method="umap", eps=0.5, metric="euclidean", cluster='dbscan', clus=8):
    np.seterr(divide='ignore', invalid='ignore')
    if mode == 0 or mode == -1:
        transE_emb = torch.load("disease_src/diseaseBase/tensor_transE.pt")
    elif mode == 1:
        transR_emb = torch.load("disease_src/diseaseBase/tensor_transR_s.pt")
    elif mode == 2:
        nltransR_emb = torch.load("disease_src/diseaseBase/tensor_NLtransR.pt")
    elif mode == 3:
        node2Vec_emb = torch.load("disease_src/diseaseBase/tensor_Node2Vec.pt")
    elif mode == 4:
        gae_emb = torch.load("disease_src/diseaseBase/tensor_GAE.pt")
    elif mode == 5:
        hash_emb = torch.load("disease_src/diseaseBase/tensorHashGNN.pt")
    elif mode == 6:
        basicm_emb = torch.load("disease_src/diseaseBase/tensorBasicModel.pt")
    prx = pd.read_csv("proximity_dis_dis_filt.csv")
    dis_map, rev_dis_map = load_diseases_dict()
    prx["diseaseAid"] = prx.disease_A.map(dis_map)
    prx["diseaseBid"] = prx.disease_B.map(dis_map)
    disA = prx.diseaseAid.unique()
    disB = prx.diseaseBid.unique()
    uniqueData = np.unique(np.append(disA, disB))
    print(len(uniqueData))
    if mode == 7:
        sp_matr = load_naive_method(uniqueData)
    # _, interA, interB = np.intersect1d(disA, disB)
    if reduction_method == "pca":
        red_method = PCA(n_components=2)
    elif reduction_method == "tsne":
        red_method = TSNE(n_components=2, metric='euclidean')
    elif reduction_method == "umap":
        red_method = UMAP(n_components=2)
    if mode == 0:
        if cluster == "spectral":
            data_d = transE_emb.detach().cpu().numpy().astype(np.float32)
        elif cluster == "dbscan":
            data_d = StandardScaler().fit_transform(transE_emb.detach().cpu())
            data_d = red_method.fit_transform(data_d)
        sp_matr = scipy.sparse.coo_matrix(
            (np.diag(pairwise_distances(data_d[prx["diseaseAid"].to_numpy()],
                                        Y=data_d[prx["diseaseBid"].to_numpy()], metric=metric)),
             (np.searchsorted(uniqueData, prx["diseaseAid"].to_numpy()),
              np.searchsorted(uniqueData, prx["diseaseBid"].to_numpy()))),
            shape=(len(uniqueData), len(uniqueData)))  # shape=(transE_emb.shape[0], transE_emb.shape[0])).tocsr()
        if cluster == "dbscan":
            db = DBSCAN(metric="precomputed", eps=eps, min_samples=2).fit_predict(sp_matr)
        elif cluster == "spectral":
            db = SpectralClustering(n_clusters=clus, affinity='precomputed').fit_predict(sp_matr)
        # db = DBSCAN(eps=0.5, min_samples=2).fit_predict(data_d)
        print(" Trans E labels")
        # plot_clusters(db, data_d, "TransE_clusters.png", dissimilar="no precomputed")
        plot_clusters(db, data_d[uniqueData], "TransE_clusters.png", dissimilar="precomputed")
        print(" Trans E cluster done")
    elif mode == 1:
        if cluster == "spectral":
            data_d = transR_emb.detach().cpu().numpy().astype(np.float32)
        elif cluster == "dbscan":
            data_d = StandardScaler().fit_transform(transR_emb.detach().cpu())
            data_d = red_method.fit_transform(data_d)
        sp_matr = scipy.sparse.coo_matrix(
            (np.diag(pairwise_distances(data_d[prx["diseaseAid"].to_numpy()],
                                        Y=data_d[prx["diseaseBid"].to_numpy()], metric=metric)),
             (np.searchsorted(uniqueData, prx["diseaseAid"].to_numpy()),
              np.searchsorted(uniqueData, prx["diseaseBid"].to_numpy()))),
            shape=(len(uniqueData), len(uniqueData)))
        if cluster == "dbscan":
            db = DBSCAN(metric="precomputed", eps=eps, min_samples=2).fit_predict(sp_matr)
        elif cluster == 'spectral':
            # db = DBSCAN(eps=0.5, min_samples=2).fit_predict(data_d)
            db = SpectralClustering(n_clusters=clus, affinity='precomputed').fit_predict(sp_matr)
        print(" Trans R labels")
        plot_clusters(db, data_d[uniqueData], "TransR_clusters.png", dissimilar="no precomputed", UMA=True)
        print(" Trans R cluster done")
    elif mode == 2:
        if cluster == 'spectral':
            data_d = nltransR_emb.detach().cpu().numpy().astype(np.float32)
        elif cluster == 'dbscan':
            data_d = StandardScaler().fit_transform(nltransR_emb.detach().cpu())
            data_d = red_method.fit_transform(data_d)
        sp_matr = scipy.sparse.coo_matrix(
            (np.diag(pairwise_distances(data_d[prx["diseaseAid"].to_numpy()],
                                        Y=data_d[prx["diseaseBid"].to_numpy()], metric=metric)),
             (np.searchsorted(uniqueData, prx["diseaseAid"].to_numpy()),
              np.searchsorted(uniqueData, prx["diseaseBid"].to_numpy()))),
            shape=(len(uniqueData), len(uniqueData)))
        if cluster == 'dbscan':
            db = DBSCAN(metric="precomputed", eps=eps, min_samples=2).fit_predict(sp_matr)
        elif cluster == 'spectral':
            db = SpectralClustering(n_clusters=clus, affinity='precomputed').fit_predict(sp_matr)
        # db = DBSCAN(eps=0.5, min_samples=2).fit_predict(data_d)
        print(" Non Linear Trans R labels")
        plot_clusters(db, data_d[uniqueData], "NonLinearTransR_clusters.png", dissimilar="no precomputed", UMA=True)
        print(" Non Linear Trans R cluster done")
    elif mode == 3:
        if cluster == 'spectral':
            data_d = node2Vec_emb.detach().cpu().numpy().astype(np.float32)
        elif cluster == 'dbscan':
            data_d = StandardScaler().fit_transform(node2Vec_emb.detach().cpu())
            data_d = red_method.fit_transform(data_d)
        sp_matr = scipy.sparse.coo_matrix(
            (np.diag(pairwise_distances(data_d[prx["diseaseAid"].to_numpy()],
                                        Y=data_d[prx["diseaseBid"].to_numpy()], metric=metric)),
             (np.searchsorted(uniqueData, prx["diseaseAid"].to_numpy()),
              np.searchsorted(uniqueData, prx["diseaseBid"].to_numpy()))),
            shape=(len(uniqueData), len(uniqueData)))
        if cluster == 'dbscan':
            db = DBSCAN(metric="precomputed", eps=eps, min_samples=2).fit_predict(sp_matr)
        elif cluster == 'spectral':
            db = SpectralClustering(n_clusters=clus, affinity='precomputed').fit_predict(sp_matr)
        # db = DBSCAN(eps=0.5, min_samples=2).fit_predict(data_d)
        print(" Node2Vec labels")
        plot_clusters(db, data_d[uniqueData], "Node2Vec_clusters.png", dissimilar="no precomputed", UMA=True)
        print(" Node2Vec cluster done")
    elif mode == 4:
        if cluster == 'spectral':
            data_d = gae_emb.detach().cpu().numpy().astype(np.float32)
        elif cluster == 'dbscan':
            data_d = StandardScaler().fit_transform(gae_emb.detach().cpu())
            data_d = red_method.fit_transform(data_d)
        sp_matr = scipy.sparse.coo_matrix(
            (np.diag(pairwise_distances(data_d[prx["diseaseAid"].to_numpy()],
                                        Y=data_d[prx["diseaseBid"].to_numpy()], metric=metric)),
             (np.searchsorted(uniqueData, prx["diseaseAid"].to_numpy()),
              np.searchsorted(uniqueData, prx["diseaseBid"].to_numpy()))),
            shape=(len(uniqueData), len(uniqueData)))
        if cluster == 'dbscan':
            db = DBSCAN(metric="precomputed", eps=eps, min_samples=2).fit_predict(sp_matr)
        elif cluster == 'spectral':
            db = SpectralClustering(n_clusters=clus, affinity='precomputed').fit_predict(sp_matr)
        # db = DBSCAN(eps=0.5, min_samples=2).fit_predict(data_d)
        print(" GAE labels")
        plot_clusters(db, data_d[uniqueData], "GAE_clusters.png", dissimilar="no precomputed", UMA=True)
        print(" GAE cluster done")
    elif mode == 5:
        if cluster == 'spectral':
            data_d = hash_emb.detach().cpu().numpy().astype(np.float32)
        elif cluster == 'dbscan':
            data_d = StandardScaler().fit_transform(hash_emb.detach().cpu())
            data_d = red_method.fit_transform(data_d)
        sp_matr = scipy.sparse.coo_matrix(
            (np.diag(pairwise_distances(data_d[prx["diseaseAid"].to_numpy()],
                                        Y=data_d[prx["diseaseBid"].to_numpy()], metric=metric)),
             (np.searchsorted(uniqueData, prx["diseaseAid"].to_numpy()),
              np.searchsorted(uniqueData, prx["diseaseBid"].to_numpy()))),
            shape=(len(uniqueData), len(uniqueData)))
        if cluster == 'dbscan':
            db = DBSCAN(metric="precomputed", eps=eps, min_samples=2).fit_predict(sp_matr)
        elif cluster == 'spectral':
            db = SpectralClustering(n_clusters=clus, affinity='precomputed').fit_predict(sp_matr)
        # db = DBSCAN(eps=0.5, min_samples=2).fit_predict(data_d)
        print(" HashGNN labels")
        plot_clusters(db, data_d[uniqueData], "HashGNN_clusters.png", dissimilar="no precomputed", UMA=True)
        print(" HashGNN cluster done")
    elif mode == 6:
        if cluster == 'spectral':
            data_d = basicm_emb.detach().cpu().numpy().astype(np.float32)
        elif cluster == 'dbscan':
            data_d = StandardScaler().fit_transform(basicm_emb.detach().cpu())
            data_d = red_method.fit_transform(data_d)
        sp_matr = scipy.sparse.coo_matrix(
            (np.diag(pairwise_distances(data_d[prx["diseaseAid"].to_numpy()],
                                        Y=data_d[prx["diseaseBid"].to_numpy()], metric=metric)),
             (np.searchsorted(uniqueData, prx["diseaseAid"].to_numpy()),
              np.searchsorted(uniqueData, prx["diseaseBid"].to_numpy()))),
            shape=(len(uniqueData), len(uniqueData)))
        if cluster == 'dbscan':
            db = DBSCAN(metric="precomputed", eps=eps, min_samples=2).fit_predict(sp_matr)
        elif cluster == 'spectral':
            db = SpectralClustering(n_clusters=clus, affinity='precomputed').fit_predict(sp_matr)
        # db = DBSCAN(eps=0.5, min_samples=2).fit_predict(data_d)
        print(" Basic Model labels")
        plot_clusters(db, data_d[uniqueData], "BasicModel_clusters.png", dissimilar="no precomputed", UMA=True)
        print(" Basic Model cluster done")
    if mode == -1:
        sp_matr = scipy.sparse.coo_matrix(
            (prx["Closest distance"].to_numpy(),
             (np.searchsorted(uniqueData, prx["diseaseAid"].to_numpy()),
              np.searchsorted(uniqueData, prx["diseaseBid"].to_numpy()))),
            shape=(len(uniqueData), len(uniqueData)))  # shape=(transE_emb.shape[0], transE_emb.shape[0])).tocsr()
        if cluster == 'dbscan':
            db = DBSCAN(metric="precomputed", eps=eps, min_samples=2).fit_predict(sp_matr)
        elif cluster == 'spectral':
            db = SpectralClustering(n_clusters=clus, affinity='precomputed').fit_predict(sp_matr)
        plot_clusters(db, sp_matr, "Proximity_clusters.png", dissimilar="precomputed")
    if mode == 7:
        print(sp_matr.shape)
        if cluster == 'dbscan':
            db = DBSCAN(metric="precomputed", eps=eps, min_samples=2).fit_predict(sp_matr)
        elif cluster == 'spectral':
            db = SpectralClustering(n_clusters=clus, affinity='precomputed').fit_predict(sp_matr)
        plot_clusters(db, sp_matr, "Naive_clusters.png", dissimilar="precomputed")
    sil_c = -np.inf
    print(db)
    if len(np.unique(db)) > 1:
        sil_c = silhouette_score(sp_matr, db)
    return db, sil_c, sp_matr
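# Example invocation (added; parameter values are illustrative, mirroring the
# hyperparameter sweep in __main__ below):
#   labels, sil, dists = main(0, cluster='spectral', clus=10)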
def plot_distances(mode, metric="euclidean"):
    np.seterr(divide='ignore', invalid='ignore')
    if mode == 0 or mode == -1:
        data_d = torch.load("disease_src/diseaseBase/tensor_transE.pt")
        extension = "transE"
    elif mode == 1:
        data_d = torch.load("disease_src/diseaseBase/tensor_transR_s.pt")
        extension = "transR"
    elif mode == 2:
        data_d = torch.load("disease_src/diseaseBase/tensor_NLtransR.pt")
        extension = "nl_transE"
    elif mode == 3:
        data_d = torch.load("disease_src/diseaseBase/tensor_Node2Vec_s.pt")
        extension = "node2vec"
    elif mode == 4:
        data_d = torch.load("disease_src/diseaseBase/tensor_GAE.pt")
        extension = "gae"
    elif mode == 5:
        data_d = torch.load("disease_src/diseaseBase/tensorHashGNN.pt")
        extension = "hashGNN"
    elif mode == 6:
        data_d = torch.load("disease_src/diseaseBase/tensorBasicModel.pt")
        extension = "basicm"
    elif mode == 7:
        data_d = torch.load("disease_src/diseaseBase/tensorEdgeBModel.pt")
        extension = "edgebm"
    prx = pd.read_csv("proximity_dis_dis_filt.csv")
    data_d = data_d.detach().cpu().numpy().astype(np.float64)
    dis_map, rev_dis_map = load_diseases_dict()
    print(len(dis_map.keys()))
    print(len(data_d))
    assert len(dis_map.keys()) == len(data_d)
    prx["diseaseAid"] = prx.disease_A.map(dis_map)
    prx["diseaseBid"] = prx.disease_B.map(dis_map)
    disA = prx.diseaseAid.unique()
    disB = prx.diseaseBid.unique()
    dist = np.diag(pairwise_distances(data_d[prx["diseaseAid"].to_numpy()],
                                      Y=data_d[prx["diseaseBid"].to_numpy()], metric=metric))
    print(dist)
    uniqueData = np.unique(np.append(disA, disB))
    cui_list = pd.Series(uniqueData).map(rev_dis_map).to_list()
    cuiL = batch_process_cuis(cui_list, max_threads=10)
    print(cuiL)
    new_dict = get_none_name([k for k, v in cuiL.items() if v is None])
    # print(new_dict)
    red_method = UMAP(n_components=3)
    transformed_emb = red_method.fit_transform(data_d[uniqueData])
    fig = plt.figure()
    ax2 = fig.add_subplot(projection='3d')
    ax2.scatter(transformed_emb[:, 0], transformed_emb[:, 1], transformed_emb[:, 2], marker='o', linewidth=0.5)
    plt.title("embeddings " + extension)
    fig.savefig('disease_emb_' + extension + '.png')
    cui_map = {k: v[0][0] if v is not None else None for k, v in cuiL.items()}
    # prx["diseaseA"] = prx.disease_A.apply(lambda x: cui_map.get(x, x))
    # prx["diseaseB"] = prx.disease_B.apply(lambda x: cui_map.get(x, x))
    prx["diseaseA"] = prx.disease_A.map(cui_map)
    prx["diseaseB"] = prx.disease_B.map(cui_map)
    prx = prx.mask(prx.astype(object).eq('None')).dropna()
    unique_cuis = np.unique(np.concatenate([prx.diseaseA.unique(), prx.diseaseB.unique()])).tolist()
    code = {v[0][0]: v[0][1] for _, v in cuiL.items() if v is not None}
    # code.update({new_dict[k]: "--" for k, v in cuiL.items() if v is None})
    dist = np.diag(pairwise_distances(data_d[prx["diseaseAid"].to_numpy()],
                                      Y=data_d[prx["diseaseBid"].to_numpy()], metric=metric))
    group_cuis = plot_hebundling(unique_cuis, prx, dist, code, extension=extension)
    # heat_map(prx, dist, extension=extension)
    print(data_d[prx["diseaseAid"].to_numpy()].shape)
    disA = prx.diseaseAid.unique()
    disB = prx.diseaseBid.unique()
    uniqueData_red = np.unique(np.append(disA, disB))
    uniqueData_tags = [cuiL[rev_dis_map[y]][0][1] for y in uniqueData_red]
    unique_data = pd.DataFrame({"data": uniqueData_red, "tags": uniqueData_tags})
    unique_data = unique_data.drop_duplicates(subset=["tags"])
    dist = pairwise_distances(data_d[unique_data["data"].to_numpy()],
                              Y=data_d[unique_data["data"].to_numpy()], metric=metric)
    # sp_matr = scipy.sparse.coo_matrix((dist,
    #     (np.searchsorted(uniqueData_red, prx["diseaseAid"].to_numpy()),
    #      np.searchsorted(uniqueData_red, prx["diseaseBid"].to_numpy()))),
    #     shape=(len(uniqueData_red), len(uniqueData_red)))
    # sp_matr = sp_matr.todense()
    # sp_matr[sp_matr == 0] = np.nan
    # np.fill_diagonal(sp_matr, 0)
    print(dist)
    dist = scipy.spatial.distance.squareform(dist)
    link = scipy.cluster.hierarchy.linkage(dist)
    plt.figure(figsize=(15, 15))
    dendro = scipy.cluster.hierarchy.dendrogram(link, labels=unique_data["tags"].to_numpy(),
                                                leaf_rotation=90., leaf_font_size=15)
    plt.savefig('dendrogram_' + extension + '.png', dpi=520, format='png', bbox_inches='tight')
    unique_cuis = np.unique(np.array(list(dis_map.keys()))).tolist()
    # cuiL = batch_process_cuis(unique_cuis, max_threads=4)
    code = [v[0][1] for _, v in cuiL.items() if v is not None]
    cuiL = {k: v for k, v in cuiL.items() if v is not None}
    data_x = np.array([data_d[dis_map[k]] for k in cuiL.keys()])
    dic = {"CUI": list(cuiL.keys()), "embedding": data_x.tolist(), "ICM10CM": code}
    resul = pd.DataFrame(dic)
    resul.to_csv("embedding_df_" + extension + ".csv")


if __name__ == "__main__":
    list_of_db = [None for _ in range(9)]
    sp = [None for _ in range(8)]  # one distance matrix per embedding mode 0-7
    for i in range(8):
        plot_distances(i)
    max_indexes = 0
    final_sp = [-np.inf for _ in range(8)]
    best_eps = [[0, -np.inf, -np.inf] for _ in range(0, 9)]
    cl = 'spectral'
    if cl == 'dbscan':
        for eps in np.arange(0.1, 0.5, 0.1):
            for raps in np.arange(0.1, 0.5, 0.1):
                list_of_db_2 = []
                for i in range(8):
                    ra_db, sil_sc, _ = main(i, eps=raps)  # , metric='chebyshev')
                    list_of_db_2.append(ra_db)
                    if sil_sc > best_eps[i][1]:
                        best_eps[i] = [raps, sil_sc]
                        list_of_db[i] = ra_db
                ref_db, sil_sc, _ = main(7, eps=eps)
                if sil_sc > best_eps[-1][1]:
                    best_eps[-1] = [eps, sil_sc]
                    list_of_db[-1] = ref_db
                mean_indexes = 0
    elif cl == 'spectral':
        for cls in range(3, 50, 5):
            list_of_db_2 = []
            for i in range(8):
                ra_db, sil_sc, sp_mat = main(i, clus=cls, cluster='spectral')  # , metric='chebyshev')
                sp[i] = sp_mat
                list_of_db_2.append(ra_db)
                if sil_sc > best_eps[i][1]:
                    best_eps[i] = [cls, sil_sc]
                    list_of_db[i] = ra_db
            ref_db, sil_sc, sp_mat = main(7, clus=cls, cluster='spectral')  # , metric='chebyshev')
            if sil_sc > best_eps[-1][1]:
                best_eps[-1] = [cls, sil_sc]
                list_of_db[-1] = ref_db
            sp_res = []
            for mat in sp:
                if mat is not None:
                    print(mat)
                    print(sp_mat)
                    sp_res.append(np.sum((mat - sp_mat) ** 2) ** 0.5)
                else:
                    sp_res.append(-np.inf)
            for i, y in enumerate(sp_res):
                if y < final_sp[i] or final_sp[i] <= -np.inf:
                    final_sp[i] = y
            # for u in list_of_db_2:
            #     mean_indexes = adjusted_rand_score(ref_db, u)
            # mean_indexes /= len(list_of_db_2)
            # if mean_indexes > max_indexes:
            #     list_of_db = list_of_db_2
    for i, u in enumerate(list_of_db):
        if u is not None:
            print(" the nmi " + str(i) + " is of: " + str(normalized_mutual_info_score(ref_db, u)))
            print(" the mixed info " + str(i) + " is of: " + str(mixed_info(ref_db, u)))
            print(" the adjusted rand score " + str(i) + " is of: " + str(adjusted_rand_score(ref_db, u)))
    for i, u in enumerate(best_eps):
        print("best eps of " + str(i) + " is: " + str(u))
    print(final_sp)
    for i, r in enumerate(final_sp):
        print("best correlation of " + str(i) + ": " + str(r))