print_shapes.py 445 Bytes
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
import os
import numpy as np

def print_shapes(directory):
    """Print the shape of every ``.npy`` array stored directly in *directory*.

    Entries are visited in sorted name order so the output is reproducible
    (``os.listdir`` returns names in arbitrary order). Each array is opened
    memory-mapped, so its shape is available without reading the whole
    payload into memory.

    Args:
        directory: Path of the folder to scan. Sub-folders are not searched.

    Raises:
        OSError: If *directory* cannot be listed or a file cannot be opened.
        ValueError: If a ``.npy`` file is malformed or stores Python objects
            (such arrays cannot be memory-mapped).
    """
    for file in sorted(os.listdir(directory)):
        if not file.endswith(".npy"):
            continue
        path = os.path.join(directory, file)
        # A directory that merely happens to be named "*.npy" is not an array.
        if not os.path.isfile(path):
            continue
        # mmap_mode="r" parses the header and maps the data lazily, which is
        # all that is needed to report the shape.
        data = np.load(path, mmap_mode="r")
        print(f"{file}: {data.shape}")

# Example usage: report the shapes of the arrays saved in the folder below.
directory_path = './output/shap_inter_values'
print_shapes(directory=directory_path)