# compute_shap_inter_vals.py
# Computing SHAP Interaction Values
# Author: Joaquín Torres Bravo
"""
    Script to compute SHAP interaction values for chosen models.
"""

# Libraries
# --------------------------------------------------------------------------------------------------------
import os
import pickle # Loading/saving models

import numpy as np
import pandas as pd
import shap # Explainability
# Models
from xgboost import XGBClassifier
from sklearn.ensemble import RandomForestClassifier, BaggingClassifier, AdaBoostClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.svm import SVC
from sklearn.linear_model import  LogisticRegression
from sklearn.tree import DecisionTreeClassifier
# --------------------------------------------------------------------------------------------------------

# Reading test data
# --------------------------------------------------------------------------------------------------------
def read_test_data(attribute_names, base_dir='../gen_train_data/output'):
    """
    Load the 'pre' and 'post' test sets stored as .npy files.

    Parameters
    ----------
    attribute_names : list of str
        Column names, in order, applied to each X matrix. Each X array must
        have exactly len(attribute_names) columns.
    base_dir : str, optional
        Directory that contains the ``pre/`` and ``post/`` sub-folders with
        ``X_test_<group>.npy`` and ``y_test_<group>.npy``. Defaults to the
        location used by the original script.

    Returns
    -------
    dict
        Keys, in this order: ``X_test_pre``, ``y_test_pre``, ``X_test_post``,
        ``y_test_post``. X entries are DataFrames with ``convert_dtypes()``
        applied; y entries are the arrays exactly as loaded.
    """
    data_dic = {}
    for group in ('pre', 'post'):
        group_dir = f"{base_dir}/{group}"
        # allow_pickle is needed if the arrays were saved with dtype=object
        X_test = np.load(f"{group_dir}/X_test_{group}.npy", allow_pickle=True)
        y_test = np.load(f"{group_dir}/y_test_{group}.npy", allow_pickle=True)
        # Type conversion needed: convert_dtypes() infers a concrete (nullable)
        # dtype per column instead of keeping a generic object dtype.
        data_dic[f"X_test_{group}"] = pd.DataFrame(X_test, columns=attribute_names).convert_dtypes()
        data_dic[f"y_test_{group}"] = y_test
    return data_dic
# --------------------------------------------------------------------------------------------------------

if __name__ == "__main__":

    # Setup
    # --------------------------------------------------------------------------------------------------------
    # Retrieve attribute names in order
    attribute_names = list(np.load('../EDA/output/feature_names/all_features.npy', allow_pickle=True))
    # Reading data
    data_dic = read_test_data(attribute_names)
    # Sampling-method labels, processed in this order
    method_names = {
        0: "ORIG",
        1: "ORIG_CW",
        2: "OVER",
        3: "UNDER"
    }
    # Fitted model chosen for each sampling method
    model_choices = {
        "ORIG": "XGB",
        "ORIG_CW": "RF",
        "OVER": "XGB",
        "UNDER": "XGB"
    }
    # Model types handled without shap.TreeExplainer (not tree-based)
    non_tree_models = {'LR', 'SVM', 'MLP'}
    # Fail before any expensive work: interaction values are only computed
    # here via shap.TreeExplainer, so a non-tree choice cannot be processed.
    # (Previously this surfaced as a NameError or, worse, silently reused the
    # explainer of the previous model.)
    unsupported = {m: n for m, n in model_choices.items() if n in non_tree_models}
    if unsupported:
        raise NotImplementedError(
            f"SHAP interaction values are only supported for tree models; unsupported choices: {unsupported}"
        )
    output_dir = "./output/shap_inter_values"
    # np.save does not create missing directories
    os.makedirs(output_dir, exist_ok=True)
    # --------------------------------------------------------------------------------------------------------

    # Shap value generation
    # --------------------------------------------------------------------------------------------------------
    for group in ['pre', 'post']:
        # Test features for this group (column names already attached)
        X_test = data_dic['X_test_' + group]
        for method_name in method_names.values():
            print(f"{group}-{method_name}")
            model_name = model_choices[method_name]
            model_path = f"../model_selection/output/fitted_models/{group}_{method_name}_{model_name}.pkl"
            # Load the fitted model from disk.
            # NOTE: pickle.load can execute arbitrary code; only load model files
            # produced by this project's own model-selection step.
            with open(model_path, 'rb') as file:
                fitted_model = pickle.load(file)
            # A fresh explainer per model (all choices were validated as tree models above)
            explainer = shap.TreeExplainer(fitted_model)
            # Compute shap values
            shap_interaction_values = explainer.shap_interaction_values(X_test)
            # ---------------------------------------------------------------------------------------------------------
            # Save results (np.save appends the .npy extension)
            np.save(f"{output_dir}/{group}_{method_name}", shap_interaction_values)
    # --------------------------------------------------------------------------------------------------------