diff --git a/explicability/print_shapes.py b/explicability/print_shapes.py new file mode 100644 index 0000000000000000000000000000000000000000..91a6342b54d4c774b5a50ae8c7c57b5adac29bbd --- /dev/null +++ b/explicability/print_shapes.py @@ -0,0 +1,15 @@ +import os +import numpy as np + +def print_shapes(directory): + # List all files in the directory + for file in os.listdir(directory): + if file.endswith(".npy"): + # Load the .npy file + data = np.load(os.path.join(directory, file)) + # Print the shape of the numpy array + print(f"{file}: {data.shape}") + +# Example usage: +directory_path = './output/shap_inter_values' +print_shapes(directory_path) \ No newline at end of file