utilities.py 5.04 KB
Newer Older
ADRIAN  AYUSO MUNOZ's avatar
ADRIAN AYUSO MUNOZ committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166
import numpy as np
import pandas as pd
import sklearn.metrics as metrics
import matplotlib.pyplot as plt


"""
Calculate and plot ROC curve.
Input:
    y_test: Real labels.
    preds: Predictions.
    edge: Edge to study.
    extension: Extension to the figure file.
Output:
    False positive and true positive rate and the area under the curve.
"""
def plot_roc(y_test, preds, edge, extension=''):
    fpr, tpr, threshold = metrics.roc_curve(y_test, preds)
    roc_auc = metrics.auc(fpr, tpr)    
    
    label = " ".join(edge) + " " + 'AUC = %0.2f' % roc_auc
    fileName = 'metrics/aucroc' + extension+'.svg'
    random = [[0, 1], [0, 1], 'r--']
    
    plotAndSaveFig(title='ROC Curve', x=fpr, y=tpr, label=label, path=fileName,
                   loc='lower right', xlim=[0, 1], ylim=[0, 1], xlabel='False Positive Rate',
                   ylabel='True Positive Rate', random=random)

    print("AUC:", roc_auc)
    return fpr, tpr, roc_auc


"""
Calculate and plot PR curve.
Input:
    y_test: Real labels.
    preds: Predictions.
    edge: Edge to study.
    extension: Extension to the figure file.
Output:
    Recall, precision and the area under the curve.
"""
def plot_prc(y_true, y_pred, edge, extension=''):
    precision, recall, trhesholds = metrics.precision_recall_curve(y_true, y_pred)
    average_precision = metrics.average_precision_score(y_true, y_pred)
    
    label = " ".join(edge) + " " + 'RPC = %0.2f' % average_precision
    fileName = 'metrics/prc'+extension+'.svg'
    
    plotAndSaveFig(title='Precision-Recall Curve', x=recall, y=precision, label=label, path=fileName,
                   loc='lower right', xlim=[0, 1], ylim=[0, 1], xlabel='Recall', ylabel='Precision')
 
    print("PRC:", average_precision)
    return recall, precision, average_precision

"""
Plot distribution of predictions of the true and random edges. Plots the distribution function and histogram.
Input:
    extension: Extension to the figure file.
"""
def plot_dist(extension=''):
    preds = [pd.read_csv('results/dis_dru_the_5013_table.csv'), pd.read_csv('results/dis_dru_the_5013R_table.csv')]
    df = pd.DataFrame(data={'RepoDB': preds[0]['pred'], 'Random': preds[1]['pred']})

    ax = df.plot.hist(bins=50, alpha=0.5)
    ax.set_xticks(np.arange(0, 1.1, 0.1))
    ax.set_title('RepoDB & Random Prediction Histogram')
    ax.set_xlabel('Prediction Score')
    ax.figure.savefig('results/histogram'+extension+'.svg', format='svg', dpi=1200)
    ax.figure.clf()
    
    bx = df.plot.kde()
    bx.set_title('RepoDB & Random Prediction')
    bx.set_xlabel('Prediction Score')
    bx.figure.savefig('results/distribution'+extension+'.svg', format='svg', dpi=1200)
    bx.set_xlim([0, 1])
    bx.figure.savefig('results/distribution01'+extension+'.svg', format='svg', dpi=1200)
    bx.figure.clf()

"""
It plots and saves the figure sent.
Input:
    title: Title of the plot.
    x: X axis.
    y: Y axis.
    label: Label of the plot.
    path: Path to save the figure.
    loc: Location of the legend.
    xlim: Limit of the X axis.
    ylim: Limit of the Y axis.
    xlabel: Label of the X axis.
    ylabel: Label of the Y axis.
    random: Random selector.
"""
def plotAndSaveFig(title, x, y, label, path, loc=None, xlim=None, ylim=None, xlabel=None, ylabel=None, random=None):
    plt.title(title)
    plt.plot(x, y, label=label)
    
    if loc is not None:
        plt.legend(loc=loc)
        
    if xlim is not None:
        plt.xlim(xlim)
        
    if ylim is not None:
        plt.ylim(ylim)
        
    if xlabel is not None:
        plt.xlabel(xlabel)
        
    if ylabel is not None:
        plt.ylabel(ylabel)
        
    if random is not None:
        plt.plot(random[0], random[1], random[2])
        
    plt.show()
    plt.savefig(path, format='svg', dpi=1200)
    plt.clf()

"""
It plots and saves together the figures sent.
Input:
    title: Title of the plot.
    x: X axis.
    y: Y axis.
    label: Label of the plot.
    path: Path to save the figure.
    loc: Location of the legend.
    xlim: Limit of the X axis.
    ylim: Limit of the Y axis.
    xlabel: Label of the X axis.
    ylabel: Label of the Y axis.
    random: Random selector.
"""
def plotTogether(title, x, y, label, path, loc=None, xlim=None, ylim=None, xlabel=None, ylabel=None, random=None):
    figH, axs = plt.subplots(2, figsize=(6, 10))
    figV, axs2 = plt.subplots(1, 2, figsize=(12, 4))
    
    for i, elem in enumerate(axs):
        elem.plot(x[i], y[i], label=label[i])
        elem.set_title(title[i])
    
        if loc[i] is not None:
            elem.legend(loc=loc[i])
    
        if random[i] is not None:
            elem.plot(random[i][0], random[i][1], random[i][2])
        
        if xlim[i] is not None:
            elem.set_xlim(xlim[i])
    
        if ylim[i] is not None:
            elem.set_ylim(ylim[i])
        
        if xlabel[i] is not None:
            elem.set_xlabel(xlabel)
           
        if ylabel[i] is not None:
            elem.set_ylabel(ylabel)

    axs2[0] = axs[0]
    axs2[1] = axs[1]
       
    figH.savefig(path[0], format='svg', dpi=1200)
    figV.savefig(path[1], format='svg', dpi=1200)