# cv_metrics_distr.py
"""
    Plot the distribution of the metrics obtained from cross-validation (CV) as boxplots.
"""

# Libraries
# --------------------------------------------------------------------------------------------------------
import pandas as pd
import matplotlib.pyplot as plt  # Corrected import
# --------------------------------------------------------------------------------------------------------

def metric_values(df, row_name, label_col='Unnamed: 0'):
    """Return the numeric cells of the row whose ``label_col`` equals ``row_name``.

    Each metrics sheet holds one row per ``<model>_<metric>`` pair (e.g.
    ``'DT_F1'``). The label sits in a column with a blank header, which
    pandas names ``'Unnamed: 0'`` on read.

    Args:
        df: One sheet of the metrics workbook.
        row_name: Row label to look up, e.g. ``'RF_AUROC'``.
        label_col: Column that holds the row labels.

    Returns:
        A float numpy array with every column of the first matching row
        except ``label_col`` (presumably one value per CV fold -- TODO confirm).

    Raises:
        KeyError: If no row carries ``row_name``. This replaces the opaque
            ``IndexError`` that positional ``.iloc[0]`` would give.
        ValueError: If a remaining cell cannot be converted to float.
    """
    matches = df.loc[df[label_col] == row_name]
    if matches.empty:
        raise KeyError(f'row {row_name!r} not found in column {label_col!r}')
    return matches.iloc[0].drop(label_col).to_numpy(dtype=float)


def plot_sheet(df, model_names, metric_names):
    """Draw one boxplot panel per metric, with one box per model.

    Args:
        df: One sheet of the metrics workbook (see ``metric_values``).
        model_names: Models to include, in x-axis order.
        metric_names: Metrics to plot, one panel (row) each, top to bottom.

    Returns:
        The new matplotlib Figure. The caller owns it and must save/close it.
    """
    # squeeze=False keeps ``axes`` 2-D even when there is a single metric,
    # so indexing below never hits a bare Axes object.
    fig, axes = plt.subplots(len(metric_names), 1,
                             figsize=(15, 8 * len(metric_names)), squeeze=False)
    for ax, metric_name in zip(axes[:, 0], metric_names):
        data = [metric_values(df, f'{model_name}_{metric_name}')
                for model_name in model_names]
        ax.boxplot(data)
        # boxplot places ticks at 1..n, so plain tick labels line up with boxes.
        ax.set_xticklabels(model_names)
        ax.set_title(metric_name)
    fig.tight_layout()
    return fig


if __name__ == "__main__":

    metric_names = ['F1', 'PREC', 'REC', 'ACC', 'NREC', 'TN', 'FN', 'FP', 'TP', 'AUROC', 'AUPRC']
    model_names_simple = ['DT', 'RF', 'Bagging', 'AB', 'XGB', 'LR', 'SVM', 'MLP']
    model_names_cs = ['DT', 'RF', 'Bagging', 'AB', 'LR', 'SVM']
    groups = ['pre', 'post']
    methods = ['_ORIG', '_ORIG_CW', '_OVER', '_UNDER']

    # Distribution of cv metrics
    # --------------------------------------------------------------------------------------------------------
    # Parse every needed sheet in a single pass over the workbook; a list of
    # sheet names makes read_excel return a dict keyed by sheet name.
    sheets = pd.read_excel('./output_cv_metrics/metrics.xlsx',
                           sheet_name=[group + method for group in groups for method in methods])
    for group in groups:
        for method in methods:
            sheet_name = group + method
            # Cost-sensitive training was run for only a subset of the models
            model_names = model_names_cs if method == '_ORIG_CW' else model_names_simple
            fig = plot_sheet(sheets[sheet_name], model_names, metric_names)
            # NOTE: filename pattern chosen for this script; change it if the
            # figures should be written elsewhere.
            fig.savefig(f'./output_cv_metrics/cv_metrics_distr_{sheet_name}.png')
            # Release the figure; pyplot otherwise keeps every figure alive.
            plt.close(fig)
    # --------------------------------------------------------------------------------------------------------