diff --git a/code/shap/plotSHAP.py b/code/shap/plotSHAP.py new file mode 100644 index 0000000000000000000000000000000000000000..f3f2802cf0cfdd1a4bd347c439bb58ede8531afc --- /dev/null +++ b/code/shap/plotSHAP.py @@ -0,0 +1,96 @@ +import pandas as pd +import numpy as np +import shap + +import matplotlib.pyplot as plt + +from os import listdir + +from sklearn.model_selection import train_test_split + + +def getDatasets(dataset, f): + # Import of database + label = pd.read_csv("labels.csv") + if dataset == "Dropout_1": + if f == "_Filtered": + db_cluster = pd.read_csv("featsGR_cluster.csv", sep=",") + db = pd.read_csv("featsGR.csv", sep=",") + else: + db_cluster = pd.read_csv("featsCluster.csv", sep=",").drop(columns="Unnamed: 0") + db = pd.read_csv("feats.csv", sep=",").drop(columns="Unnamed: 0") + + + # Creation of train and test sets for the set without cluster + columns_to_be_changed = db.select_dtypes(exclude='number').columns.values + sin_cluster_data_features = pd.get_dummies(db, columns=columns_to_be_changed) + + # Creation of train and test sets for the set with cluster + columns_to_be_changed = db_cluster.select_dtypes(exclude='number').columns.values + cluster_data_features = pd.get_dummies(db_cluster, columns=columns_to_be_changed) + + for col1 in sin_cluster_data_features: + sin_cluster_data_features[col1] = sin_cluster_data_features[col1].astype(float) + for col2 in cluster_data_features: + cluster_data_features[col2] = cluster_data_features[col2].astype(float) + + + return sin_cluster_data_features, cluster_data_features, label + +def plots(shap_values, tFeatures, name): + + print(shap_values.shape) + print(tFeatures.shape) + print(name) + + shap.summary_plot(shap_values, tFeatures, plot_type="bar", show=False, max_display=10, plot_size=(20, 8)) + plt.savefig('figures/'+name+'_bar.svg', format='svg', dpi=1200) + plt.clf() + plt.xscale('log') + shap.summary_plot(shap_values, tFeatures, plot_type="dot", show=False, max_display=10, plot_size=(20, 8)) + plt.savefig('figures/'+name+'_dot.svg', format='svg', dpi=1200) + plt.clf() + print(":::::::::::::::::::::::::::::::::::::::::::::::::::") + +if __name__ == "__main__": + + datasets = ["Dropout_1"] + filtered = ["","_Filtered"] + for f in filtered: + for d in datasets: + sin_cluster_data_features, cluster_data_features, label = getDatasets(d, f) + + shap.initjs() + + train_data_features, test_data_features, train_data_label, test_data_label = train_test_split(sin_cluster_data_features, + label, + test_size=0.2, + random_state=25) + train_data_features_cluster, test_data_features_cluster, train_data_label_cluster, test_data_label_cluster = train_test_split( + cluster_data_features, label, test_size=0.2, random_state=25) + + features = list(train_data_features.columns.values) # beware that this will change in case of FSS + featuresC = list(train_data_features_cluster.columns.values) + + for file in listdir("shapValues/" + d + f): + shapValue = np.load("shapValues/" + d + f + '/' + file) + nameO = file.split(".")[0] + + cluster = nameO[-1] + tree = nameO[10] == 'T' or nameO[10] == 'R' + if cluster == 'C': + test_data_feats = test_data_features_cluster[:500] + feats = featuresC + else: + test_data_feats = test_data_features[:500] + feats = features + + if tree: + shape = [2, 500, len(feats)] + else: + shape = [500, len(feats)] + + shapValue = np.reshape(shapValue, shape) + if tree: + shapValue = shapValue[0] + plots(shapValue, test_data_feats, nameO)