# analysis_visualization.py
# (committed by Laura Masa)
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

def filter_cd_zscore(pnd_df, *, zscore_threshold=-0.15, quantile=0.75):
    """
    Select the 'unknown'-treatment rows of pnd_df that are close to the disease
    according to both 'Closest distance' and 'Dc_zscore'.

    A row is kept when all of the following hold:
    - 'Treatment' == 'unknown'
    - 'Dc_zscore' is strictly below zscore_threshold
    - 'Closest distance' is at most the given quantile (Q3 by default) of the
      'Closest distance' values of the rows whose 'Treatment' is 'yes'

    Input Parameters:
    - pnd_df (pd.DataFrame): DataFrame with the columns 'Treatment',
      'Closest distance' and 'Dc_zscore'.
    - zscore_threshold (float): Exclusive upper bound on 'Dc_zscore'.
      Defaults to -0.15.
    - quantile (float): Quantile (between 0 and 1) of the known treatments'
      'Closest distance' used as the inclusive upper bound. Defaults to 0.75.

    Returns:
    - filtered_df (pd.DataFrame): The rows of pnd_df that satisfy all three
      conditions, with their original index. If pnd_df has no 'yes' rows the
      quantile is NaN, every comparison is False and the result is empty.
    """
    # The distance cutoff is derived from the known treatments only
    known_treatments = pnd_df[pnd_df['Treatment'] == 'yes']
    distance_cutoff = known_treatments['Closest distance'].quantile(quantile)

    keep = (
        (pnd_df['Dc_zscore'] < zscore_threshold)
        & (pnd_df['Closest distance'] <= distance_cutoff)
        & (pnd_df['Treatment'] == 'unknown')
    )
    filtered_df = pnd_df[keep]

    return filtered_df



def rep_pnd(pnd_df, filtered_pnd, output_path='../results/pnd_closest_distance_boxplot_filtered.png'):
    """
    Generates side-by-side boxplots to compare 'Closest distance' and 'Personalized Network Distance (PND)' metrics
    across three groups: known treatments, all unknown drugs and the filtered unknown drugs.

    Input Parameters:
    - pnd_df (pd.DataFrame): DataFrame containing drug data with columns:
      - 'Treatment': The treatment status of the drugs ('yes' or 'unknown').
      - 'Closest distance': A numeric measure of the closest distance metric for the drugs.
      - 'PND': A numeric measure of Personalized Network Distance for the drugs.
    - filtered_pnd (pd.DataFrame): Rows to show as the 'Filtered unknown for DR' group, with the same columns as pnd_df.
      Presumably a subset of the 'unknown' rows of pnd_df; such rows are then also counted in 'All unknown'.
    - output_path (str): Where the PNG is written. Defaults to the original location under '../results/'.

    Returns:
    - None: The function saves a boxplot comparison as a PNG file and displays the plot.
    """
    # One fixed colour per group, shared by both panels
    palette = {
        'Treatment': '#FF7A7A',
        'All unknown': '#79C4FF',
        'Filtered unknown for DR': '#CAF0F8',
    }

    drugs_with_disease = pnd_df[pnd_df['Treatment'] == 'yes']
    drugs_without_disease = pnd_df[pnd_df['Treatment'] == 'unknown']

    # Stack the three groups, overwriting 'Treatment' with the label shown on the x axis
    combined_data = pd.concat([
        drugs_with_disease.assign(Treatment='Treatment'),
        drugs_without_disease.assign(Treatment='All unknown'),
        filtered_pnd.assign(Treatment='Filtered unknown for DR'),
    ])

    fig, axes = plt.subplots(1, 2, figsize=(14, 6))  # 1 row, 2 columns

    # (column to plot, y-axis label): 'Closest distance' on the left, 'PND' on the right.
    # The PND label is a raw string so that '\m' is not parsed as an (invalid) escape sequence.
    panels = [
        ('Closest distance', 'Closest distance'),
        ('PND', r'Personalized Network Distance ($\mathregular{PND}$)'),
    ]

    for ax, (column, ylabel) in zip(axes, panels):
        sns.boxplot(
            x='Treatment',
            y=column,
            data=combined_data,
            hue='Treatment',
            ax=ax,
            palette=palette,
            dodge=False,
            medianprops=dict(linewidth=2),
        )
        ax.set_ylabel(ylabel, fontsize=12)
        ax.set_xlabel('')
        for label in ax.get_xticklabels():
            label.set_fontsize(12)
        # hue duplicates the x axis, so replace the legend with an empty one
        ax.legend([], [], frameon=False)

    plt.tight_layout()  # Adjust the layout of the subplot to avoid overlap

    plt.savefig(output_path, dpi=300)
    plt.show()





def get_gsm_ids_for_lowest_pnd_drugs(pnd_df, drugs_list, num_gsm=3):
    """
    For each drug in drugs_list, extracts the num_gsm rows of pnd_df with the lowest PND values.

    Input Parameters:
    - pnd_df (pd.DataFrame): DataFrame containing drug data. The columns used here are:
      - 'Drugs': The drug name, matched exactly against the entries of drugs_list.
      - 'PND': A numeric measure of Personalized Network Distance, used for ranking.
      Any other columns (presumably including the gsm_id) are carried through unchanged.
    - drugs_list: Iterable of drug names to extract rows for.
    - num_gsm: Maximum number of rows (lowest PND first) to keep for each drug.

    Returns:
    - gsm_df: DataFrame with the selected rows, grouped in the order of drugs_list and sorted by
      ascending PND within each drug. Original index labels are kept. Drugs that do not occur in
      pnd_df contribute no rows; an empty drugs_list yields an empty DataFrame with the columns
      of pnd_df.
    """
    # One frame per requested drug: its rows sorted by ascending PND, truncated to num_gsm
    top_rows_per_drug = [
        pnd_df[pnd_df['Drugs'] == drug].sort_values(by='PND').head(num_gsm)
        for drug in drugs_list
    ]

    # pd.concat raises ValueError on an empty sequence, so handle "no drugs requested" explicitly
    if not top_rows_per_drug:
        return pnd_df.iloc[0:0]

    gsm_df = pd.concat(top_rows_per_drug)

    return gsm_df