utilities.py 3.81 KB
Newer Older
ADRIAN  AYUSO MUNOZ's avatar
ADRIAN AYUSO MUNOZ committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113
import numpy as np
import pandas as pd
import sklearn.metrics as metrics
import matplotlib.pyplot as plt


def plot_roc(y_test, preds, edge, folder='', extension=''):
    """Plot the ROC curve for a set of predictions and save it as SVG.

    The curve is computed with ``metrics.roc_curve``, drawn together with a
    dashed red diagonal (the reference line of a random classifier) and
    written to ``plots/<folder>metrics/aucroc<extension>.svg`` through
    ``plotAndSaveFig``. The area under the curve is printed as well.

    Args:
        y_test: Ground-truth labels, forwarded to ``metrics.roc_curve``.
        preds: Predicted scores, forwarded to ``metrics.roc_curve``.
        edge: Iterable of strings; joined with spaces to prefix the legend.
        folder: Path fragment inserted verbatim right after ``plots/``.
        extension: Suffix inserted verbatim before ``.svg``.

    Returns:
        Tuple ``(fpr, tpr, roc_auc)``.
    """
    # The thresholds returned by roc_curve are not needed here.
    fpr, tpr, _ = metrics.roc_curve(y_test, preds)
    roc_auc = metrics.auc(fpr, tpr)

    legend_text = '%s AUC = %0.2f' % (' '.join(edge), roc_auc)
    target_path = ''.join(['plots/', folder, 'metrics/aucroc', extension, '.svg'])
    # x values, y values and line format of the chance-level diagonal.
    chance_line = [[0, 1], [0, 1], 'r--']

    plotAndSaveFig(
        title='ROC Curve',
        x=fpr,
        y=tpr,
        label=legend_text,
        path=target_path,
        loc='lower right',
        xlim=[0, 1],
        ylim=[0, 1],
        xlabel='False Positive Rate',
        ylabel='True Positive Rate',
        random=chance_line,
    )

    print("AUC:", roc_auc)
    return fpr, tpr, roc_auc


def plot_prc(y_true, y_pred, edge, folder='', extension=''):
    """Plot the precision-recall curve for a set of predictions and save it.

    The curve is computed with ``metrics.precision_recall_curve`` and the
    summary score with ``metrics.average_precision_score``. The figure is
    written to ``plots/<folder>metrics/prc<extension>.svg`` through
    ``plotAndSaveFig`` and the score is printed.

    Args:
        y_true: Ground-truth labels, forwarded to the scikit-learn metrics.
        y_pred: Predicted scores, forwarded to the scikit-learn metrics.
        edge: Iterable of strings; joined with spaces to prefix the legend.
        folder: Path fragment inserted verbatim right after ``plots/``.
        extension: Suffix inserted verbatim before ``.svg``.

    Returns:
        Tuple ``(recall, precision, average_precision)``.
    """
    # The thresholds returned by precision_recall_curve are not needed here.
    precision, recall, _ = metrics.precision_recall_curve(y_true, y_pred)
    average_precision = metrics.average_precision_score(y_true, y_pred)

    # The legend previously read "RPC"; use "PRC" so it agrees with the
    # output file name and the printed summary below.
    label = " ".join(edge) + " " + 'PRC = %0.2f' % average_precision
    fileName = 'plots/' + folder + 'metrics/prc' + extension + '.svg'

    plotAndSaveFig(title='Precision-Recall Curve', x=recall, y=precision, label=label, path=fileName,
                   loc='lower right', xlim=[0, 1], ylim=[0, 1], xlabel='Recall', ylabel='Precision')

    print("PRC:", average_precision)
    return recall, precision, average_precision


def plot_dist(folder='', extension=''):
    """Plot the score distributions of the RepoDB and Random predictions.

    Reads the ``pred`` column from two fixed CSV files under ``results/``
    and writes three SVG files under ``plots/<folder>results/``:

    * ``histogram<extension>.svg``: overlaid 50-bin histograms,
    * ``distribution<extension>.svg``: kernel density estimates,
    * ``distribution01<extension>.svg``: the same KDE plot with the x axis
      limited to ``[0, 1]``.

    Args:
        folder: Path fragment inserted verbatim right after ``plots/``.
        extension: Suffix inserted verbatim before ``.svg``.
    """
    def target(stem):
        # Build the output path for one of the generated figures.
        return ''.join(['plots/', folder, 'results/', stem, extension, '.svg'])

    repodb_table = pd.read_csv('results/dis_dru_the_5013_table.csv')
    random_table = pd.read_csv('results/dis_dru_the_5013R_table.csv')
    scores = pd.DataFrame(
        data={'RepoDB': repodb_table['pred'], 'Random': random_table['pred']}
    )

    # Histogram of both score columns on shared axes.
    hist_axes = scores.plot.hist(bins=50, alpha=0.5)
    hist_axes.set_xticks(np.arange(0, 1.1, 0.1))
    hist_axes.set_title('RepoDB & Random Prediction Histogram')
    hist_axes.set_xlabel('Prediction Score')
    hist_figure = hist_axes.figure
    hist_figure.savefig(target('histogram'), format='svg', dpi=1200)
    hist_figure.clf()

    # Density estimate, saved once as drawn and once clipped to [0, 1].
    kde_axes = scores.plot.kde()
    kde_axes.set_title('RepoDB & Random Prediction')
    kde_axes.set_xlabel('Prediction Score')
    kde_figure = kde_axes.figure
    kde_figure.savefig(target('distribution'), format='svg', dpi=1200)
    kde_axes.set_xlim([0, 1])
    kde_figure.savefig(target('distribution01'), format='svg', dpi=1200)
    kde_figure.clf()


def plotAndSaveFig(title, x, y, label, path, loc=None, xlim=None, ylim=None, xlabel=None, ylabel=None, random=None):
    """Draw a single line plot on the current pyplot figure, save and show it.

    Args:
        title: Figure title.
        x: X values of the main curve.
        y: Y values of the main curve.
        label: Legend label of the main curve.
        path: Destination file; the figure is written in SVG format.
        loc: Legend location. The legend is only drawn when this is given.
        xlim: Optional x-axis limits.
        ylim: Optional y-axis limits.
        xlabel: Optional x-axis label.
        ylabel: Optional y-axis label.
        random: Optional reference line given as
            ``[x_values, y_values, format_string]``.
    """
    plt.title(title)
    plt.plot(x, y, label=label)

    if loc is not None:
        plt.legend(loc=loc)

    if xlim is not None:
        plt.xlim(xlim)

    if ylim is not None:
        plt.ylim(ylim)

    if xlabel is not None:
        plt.xlabel(xlabel)

    if ylabel is not None:
        plt.ylabel(ylabel)

    if random is not None:
        plt.plot(random[0], random[1], random[2])

    # Save BEFORE showing: with an interactive backend plt.show() blocks and
    # the figure is destroyed when its window is closed, so a later savefig
    # would write an empty figure.
    plt.savefig(path, format='svg', dpi=1200)
    plt.show()
    plt.clf()


def _draw_panel(axis, x, y, label, title, loc, xlim, ylim, xlabel, ylabel, random):
    """Draw one curve and its optional decorations on a single Axes."""
    axis.plot(x, y, label=label)
    axis.set_title(title)

    if loc is not None:
        axis.legend(loc=loc)

    if random is not None:
        axis.plot(random[0], random[1], random[2])

    if xlim is not None:
        axis.set_xlim(xlim)

    if ylim is not None:
        axis.set_ylim(ylim)

    if xlabel is not None:
        axis.set_xlabel(xlabel)

    if ylabel is not None:
        axis.set_ylabel(ylabel)


def plotTogether(title, x, y, label, path, loc=None, xlim=None, ylim=None, xlabel=None, ylabel=None, random=None):
    """Draw two plots in two layouts and save each layout as SVG.

    The same two panels are drawn on a 2-row figure (saved to ``path[0]``)
    and on a 1-row, 2-column figure (saved to ``path[1]``).

    Every per-panel argument is a sequence indexed by panel (0 and 1).

    Args:
        title: Panel titles.
        x: X values of each panel's curve.
        y: Y values of each panel's curve.
        label: Legend label of each panel's curve.
        path: Two destination files: stacked layout, then side-by-side.
        loc: Legend location per panel; a ``None`` entry disables the legend.
        xlim: X-axis limits per panel, or ``None`` entries.
        ylim: Y-axis limits per panel, or ``None`` entries.
        xlabel: X-axis label per panel, or ``None`` entries.
        ylabel: Y-axis label per panel, or ``None`` entries.
        random: Reference line per panel as
            ``[x_values, y_values, format_string]``, or ``None`` entries.

    Passing ``None`` for a whole optional argument leaves that option unset
    on every panel.
    """
    fig_stacked, axes_stacked = plt.subplots(2, figsize=(6, 10))
    fig_side, axes_side = plt.subplots(1, 2, figsize=(12, 4))
    panel_count = len(axes_stacked)

    def per_panel(option):
        # The defaults are None; expand them so indexing below is safe.
        return [None] * panel_count if option is None else option

    loc = per_panel(loc)
    xlim = per_panel(xlim)
    ylim = per_panel(ylim)
    xlabel = per_panel(xlabel)
    ylabel = per_panel(ylabel)
    random = per_panel(random)

    for i in range(panel_count):
        # An Axes belongs to exactly one figure, so each panel has to be
        # drawn on both layouts; rebinding entries of the axes array does
        # not move any artists into the other figure.
        for axis in (axes_stacked[i], axes_side[i]):
            _draw_panel(axis, x[i], y[i], label[i], title[i], loc[i],
                        xlim[i], ylim[i], xlabel[i], ylabel[i], random[i])

    fig_stacked.savefig(path[0], format='svg', dpi=1200)
    fig_side.savefig(path[1], format='svg', dpi=1200)