# shap_plots.py — generate SHAP interaction summary plots for the pre/post test sets.
# Libraries
# --------------------------------------------------------------------------------------------------------
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap
# --------------------------------------------------------------------------------------------------------

# Reading test data
# --------------------------------------------------------------------------------------------------------
def read_test_data(attribute_names, data_dir='../gen_train_data/data/output'):
    """Load the pre/post test splits stored as ``.npy`` files.

    Parameters
    ----------
    attribute_names : list
        Column names for the feature matrices, in the order the columns were saved.
    data_dir : str, optional
        Directory containing the ``pre/`` and ``post/`` sub-directories. The default is
        the path (relative to the current working directory) this script has always used.

    Returns
    -------
    dict
        Keys ``"X_test_pre"``, ``"y_test_pre"``, ``"X_test_post"``, ``"y_test_post"``.
        ``X_*`` values are DataFrames with ``convert_dtypes()`` applied; ``y_*`` values
        are the arrays exactly as loaded.
    """
    data_dic = {}
    for group in ('pre', 'post'):
        group_dir = f'{data_dir}/{group}'
        # allow_pickle=True lets object-dtype arrays load; unpickling can execute
        # arbitrary code, so only point this at trusted, locally generated files.
        X = np.load(f'{group_dir}/X_test_{group}.npy', allow_pickle=True)
        y = np.load(f'{group_dir}/y_test_{group}.npy', allow_pickle=True)
        # Type conversion needed: if the saved matrix is object-dtype, every column would
        # be `object`; convert_dtypes() re-infers a concrete dtype per column.
        data_dic[f'X_test_{group}'] = pd.DataFrame(X, columns=attribute_names).convert_dtypes()
        data_dic[f'y_test_{group}'] = y
    return data_dic
# --------------------------------------------------------------------------------------------------------

if __name__ == "__main__":
    # Setup
    # --------------------------------------------------------------------------------------------------------
    # Retrieve attribute names in the same order as the saved feature columns
    attribute_names = list(np.load('../gen_train_data/data/output/attributes.npy', allow_pickle=True))
    # Reading data
    data_dic = read_test_data(attribute_names)
    method_names = {
        0: "ORIG",
        1: "ORIG_CW",
        2: "OVER",
        3: "UNDER"
    }

    # Model chosen for each method; used here only to build the output file name.
    model_choices = {
        "ORIG": "XGB",
        "ORIG_CW": "RF",
        "OVER": "XGB",
        "UNDER": "XGB"
    }

    # Only OVER is plotted, matching the former hard-coded `if method != 'over_': continue`
    # filter. Add other names from `method_names` here to plot them as well.
    enabled_methods = {"OVER"}

    plot_dir = './output/plots/inters'
    # --------------------------------------------------------------------------------------------------------

    # Plot generation
    # --------------------------------------------------------------------------------------------------------
    # savefig raises if the target directory does not exist, so create it up front.
    os.makedirs(plot_dir, exist_ok=True)
    for method_name in method_names.values():
        if method_name not in enabled_methods:
            continue
        model_name = model_choices[method_name]
        for group in ['pre', 'post']:
            X_test = data_dic['X_test_' + group]

            print(f"{group}-{method_name}")

            # Precomputed SHAP interaction values; presumably shaped
            # (n_samples, n_features, n_features) — TODO confirm against the producer.
            shap_inter_vals = np.load(f'./output/shap_inter_values/{group}_{method_name}.npy')

            # show=False is required: by default summary_plot ends with plt.show(), which on
            # interactive backends closes the figure, so savefig would write a blank image.
            shap.summary_plot(shap_inter_vals, X_test, max_display=15, show=False)
            # dpi only affects rasterized elements embedded in the SVG.
            plt.savefig(f'{plot_dir}/{group}_{method_name}_{model_name}.svg', format='svg', dpi=1250)
            # Release the figure so the next iteration does not draw on top of it.
            plt.close()