# plotSHAP.py — render SHAP summary plots (bar and dot) from saved SHAP values.
# Author: Lucia Prieto
import os
from os import listdir

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap
from sklearn.model_selection import train_test_split


def _one_hot_float(frame):
    """Return *frame* with non-numeric columns one-hot encoded and all columns cast to float."""
    categorical = frame.select_dtypes(exclude="number").columns.values
    return pd.get_dummies(frame, columns=categorical).astype(float)


def getDatasets(dataset, f):
    """Load the feature tables (without and with cluster features) and the labels.

    Non-numeric columns are one-hot encoded with ``pd.get_dummies`` and every
    resulting column is cast to ``float``.

    Args:
        dataset: Dataset identifier. Only ``"Dropout_1"`` is supported.
        f: ``"_Filtered"`` selects the filtered files (``featsGR.csv`` /
            ``featsGR_cluster.csv``); any other value selects the unfiltered
            files (``feats.csv`` / ``featsCluster.csv``).

    Returns:
        Tuple ``(sin_cluster_data_features, cluster_data_features, label)``:
        encoded features without cluster columns, encoded features with
        cluster columns, and the ``labels.csv`` table as read.

    Raises:
        ValueError: If ``dataset`` is not a known dataset name.
    """
    # Validate before touching the disk. Previously an unknown name fell
    # through and crashed later with an UnboundLocalError on ``db``.
    if dataset != "Dropout_1":
        raise ValueError(f"Unknown dataset: {dataset!r}")

    label = pd.read_csv("labels.csv")
    if f == "_Filtered":
        db_cluster = pd.read_csv("featsGR_cluster.csv", sep=",")
        db = pd.read_csv("featsGR.csv", sep=",")
    else:
        # These files contain an unnamed leading column (presumably a saved
        # pandas index) that is not a feature.
        db_cluster = pd.read_csv("featsCluster.csv", sep=",").drop(columns="Unnamed: 0")
        db = pd.read_csv("feats.csv", sep=",").drop(columns="Unnamed: 0")

    sin_cluster_data_features = _one_hot_float(db)
    cluster_data_features = _one_hot_float(db_cluster)
    return sin_cluster_data_features, cluster_data_features, label

def plots(shap_values, tFeatures, name, *, out_dir="figures", max_display=10,
          plot_size=(20, 8), dot_xscale="symlog"):
    """Save SHAP summary plots (bar and dot) for one set of SHAP values as SVG.

    Writes ``<out_dir>/<name>_bar.svg`` and ``<out_dir>/<name>_dot.svg``.

    Args:
        shap_values: SHAP value array passed to ``shap.summary_plot``; assumed
            to be (n_samples, n_features) and aligned with ``tFeatures`` —
            TODO confirm for every explainer type.
        tFeatures: Feature values for the same samples (the caller passes a
            DataFrame slice, whose columns become the feature labels).
        name: Basename of the output files.
        out_dir: Output directory; created if it does not exist.
        max_display: Number of top-ranked features to show.
        plot_size: Figure size in inches, ``(width, height)``.
        dot_xscale: x-axis scale of the dot plot. SHAP values are signed, so
            the former ``'log'`` scale masked every non-positive point.
            ``'symlog'`` compresses large magnitudes while keeping negative
            contributions visible. Pass ``'linear'`` for a plain axis.
    """
    # Debug trace of what is about to be plotted.
    print(shap_values.shape)
    print(tFeatures.shape)
    print(name)

    # savefig raises if the target directory is missing.
    os.makedirs(out_dir, exist_ok=True)
    stem = os.path.join(out_dir, name)

    shap.summary_plot(shap_values, tFeatures, plot_type="bar", show=False,
                      max_display=max_display, plot_size=plot_size)
    plt.savefig(stem + "_bar.svg", format="svg", dpi=1200)
    plt.clf()

    # summary_plot draws into the current axes, so set the scale up front.
    plt.xscale(dot_xscale)
    shap.summary_plot(shap_values, tFeatures, plot_type="dot", show=False,
                      max_display=max_display, plot_size=plot_size)
    plt.savefig(stem + "_dot.svg", format="svg", dpi=1200)
    plt.clf()
    print(":::::::::::::::::::::::::::::::::::::::::::::::::::")

if __name__ == "__main__":
    # Maximum number of test rows for which SHAP values were saved.
    n_shap_rows = 500

    datasets = ["Dropout_1"]
    filtered = ["", "_Filtered"]

    # initjs only loads SHAP's notebook JavaScript; once is enough.
    shap.initjs()

    for f in filtered:
        for d in datasets:
            sin_cluster_data_features, cluster_data_features, label = getDatasets(d, f)

            # Recreate the test partitions (same test_size / random_state as
            # used when the SHAP values were presumably computed). Only the
            # test features are needed here.
            _, test_data_features, _, _ = train_test_split(
                sin_cluster_data_features, label, test_size=0.2, random_state=25)
            _, test_data_features_cluster, _, _ = train_test_split(
                cluster_data_features, label, test_size=0.2, random_state=25)

            shap_dir = os.path.join("shapValues", d + f)
            # Sorted so figures are produced in a reproducible order.
            for fname in sorted(listdir(shap_dir)):
                if fname.startswith("."):
                    # Hidden files (e.g. .DS_Store) are not arrays np.load can read.
                    continue
                shapValue = np.load(os.path.join(shap_dir, fname))
                nameO = fname.split(".")[0]

                # File-name convention as inferred from these checks (TODO confirm):
                # a trailing 'C' means the cluster feature set was used, and a
                # 'T' or 'R' at index 10 marks a tree-based model whose saved
                # output holds two stacked arrays (presumably one per class).
                is_cluster = nameO[-1] == "C"
                tree = nameO[10] in ("T", "R")

                source = test_data_features_cluster if is_cluster else test_data_features
                test_data_feats = source[:n_shap_rows]

                # Derive the expected shape from the rows actually plotted, so
                # the reshape agrees with the slice even for small test sets.
                n_rows, n_feats = test_data_feats.shape
                shape = [2, n_rows, n_feats] if tree else [n_rows, n_feats]
                shapValue = np.reshape(shapValue, shape)
                if tree:
                    # Keep the first of the two stacked arrays — TODO confirm
                    # that this is the intended class.
                    shapValue = shapValue[0]
                plots(shapValue, test_data_feats, nameO)