Source code for aweSOM.sce

## Statistically Combined Ensemble (SCE) code for N-dimensional data
## See Ha et al. (2024) for implementation details
## For detailed mathematical description, see Bussov & Nattila (2021)
## https://github.com/mkruuse/segmenting-turbulent-simulations-with-ensemble-learning
## If GPU-enabled, capable of handling 1000^3 data with 1 A100-80GB GPU.

import os
import glob
import argparse
import numpy as np

# Use JAX if GPU and the jax package is installed, otherwise use NumPy
# added support for experimental JAX backend with METAL on Apple silicon
try:
    import jax

    default_device = jax.default_backend()
    if default_device == "gpu" or default_device == "METAL":
        USE_JAX = True
    else:
        USE_JAX = False
except:
    USE_JAX = False

# Define a unified interface for the functions
if USE_JAX:
    from jax import numpy as jnp
    from jax import jit

    print("Using JAX for GPU computation")
    array_lib = jnp  # Use JAX's numpy interface
else:
    print("Using NumPy for CPU computation")
    array_lib = np  # Use NumPy


# Define a conditional jit decorator
[docs] def conditional_jit(func): if USE_JAX: return jit(func) else: return func
# read array from clusterID.npy
[docs] def load_som_npy(path: str) -> array_lib.ndarray: return array_lib.load(path, "r")
[docs] @conditional_jit def create_mask(img: array_lib.ndarray, cid: int) -> array_lib.ndarray: """Create a mask for a given cluster id Args: img (jnp.ndarray): 3D array of cluster ids cid (int): cluster id to mask Returns: (j)np.ndarray: masked cluster, 1 where cluster id is cid, 0 elsewhere """ return array_lib.where(img == cid, 1, 0)
[docs] def compute_SQ(mask: array_lib.ndarray, maskC: array_lib.ndarray): """Compute the quality index between two masks Args: mask ((j)np.ndarray): mask of cluster C maskC ((j)np.ndarray): mask of cluster C' Returns: SQ (float): quality index, equals to S/Q SQ_matrix ((j)np.ndarray): pixelwise quality index, equals to S/Q * mask """ # -------------------------------------------------- # product of two masked arrays; corresponds to intersection I = array_lib.multiply(mask, maskC) # -------------------------------------------------- # sum of two masked arrays; corresponds to union U = array_lib.ceil((mask + maskC) * 0.5) # U_area = array_lib.sum(U) / (nx * ny * nz) # -------------------------------------------------- # Intersection signal strength of two masked arrays, S S = array_lib.sum(I) / array_lib.sum(U) # -------------------------------------------------- # Union quality of two masked arrays, Q if array_lib.max(mask) == 0 or array_lib.max(maskC) == 0: return 0.0, array_lib.zeros(mask.shape) Q = array_lib.sum(U) / (array_lib.sum(mask) + array_lib.sum(maskC)) - array_lib.sum( I ) / (array_lib.sum(mask) + array_lib.sum(maskC)) if Q == 0.0: return 0.0, array_lib.zeros( mask.shape ) # break here because this causes NaNs that accumulate. # -------------------------------------------------- # final measure for this comparison is (S/Q) x Union SQ = S / Q SQ_matrix = SQ * mask return SQ, SQ_matrix
[docs] def loop_over_all_clusters( all_files: list[str], number_of_clusters: array_lib.ndarray, dimensions: np.ndarray, subfolder: str = "SCE", ) -> int: """ Loops over all clusters in the given data, compute goodness-of-fit, then save Gsum values to file. Args: all_files (list[str]): A list of data files saved in '.npy' format. number_of_clusters ((j)np.ndarray): An array of the number of cluster ids in each run. dimensions (np.ndarray): A 1d array representing the dimensions of the clusters (can be any dimension but nx*ny*nz has to be equal to number of data points). subfolder (str): The name of the subfolder to save the results to. Returns: Save Gsum value of each cluster C to a file. """ pass runs = all_files # [file.strip('.npy') for file in all_files] # loop over data files reading image by image for i in range(len(runs)): run = runs[i] clusters_1d = load_som_npy(run) print("-----------------------") print("Run : ", run, flush=True) with open(subfolder + "/multimap_mappings.txt", "a") as f: f.write("{}\n".format(run.strip(".npy"))) # nx x ny x nz size maps # nz,ny,nx = array_lib.cbrt(clusters_1d.shape[0]).astype(int), array_lib.cbrt(clusters_1d.shape[0]).astype(int), array_lib.cbrt(clusters_1d.shape[0]).astype(int) # clusters = clusters_1d.reshape(nz,ny,nx) clusters = clusters_1d.reshape(dimensions) # unique ids nids = number_of_clusters[i] # number of cluster ids in this run # ids = np.arange(nids) print("nids : ", nids) for cid in range(nids): # print(' -----------------------') # print(' cid : ', cid, flush=True) # create masked array where only id == cid are visible mask = create_mask(clusters, cid) total_mask = array_lib.zeros(dimensions, dtype=float) total_SQ_scalar = 0.0 for j in range(len(runs)): runC = runs[j] if j == i: # don't compare to itself continue clustersC_1d = load_som_npy(runC) clustersC = clustersC_1d.reshape(dimensions) # print(' -----------------------') # print(' ',runC, flush=True) nidsC = number_of_clusters[j] # number of cluster ids in this run # print(' nidsC : ', nidsC) for cidC in range(nidsC): maskC = create_mask(clustersC, cidC) SQ, SQ_matrix = compute_SQ(mask, maskC) # pixelwise stacking of 2 masks total_mask += SQ_matrix # for numpy array total_SQ_scalar += SQ # save total mask to file # print("Saving total mask to file", flush=True) array_lib.save( subfolder + "/mask-{}-id{}.npy".format(run.strip(".npy"), cid), total_mask, ) # print("Saving total SQ scalar to multimap_mapping", flush=True) with open(subfolder + "/multimap_mappings.txt", "a") as f: f.write("{} {}\n".format(cid, total_SQ_scalar)) return 0
[docs] def find_number_of_clusters(cluster_files: list[str]) -> array_lib.ndarray: """ Find the number of clusters in each run. Args: cluster_files (list[str]): A list of data files saved in '.npy' format. Returns: number_of_clusters ((j)np.ndarray): An array of the number of cluster ids in each run. """ number_of_clusters = np.empty(len(cluster_files), dtype=int) for run in range(len(cluster_files)): clusters = load_som_npy(cluster_files[run]) ids = array_lib.unique(clusters) number_of_clusters[run] = len(ids) return number_of_clusters
[docs] def parse_args(): """argument parser for the sce.py script""" parser = argparse.ArgumentParser(description="SCE code") parser.add_argument( "--folder", type=str, dest="folder", default=os.getcwd(), help="Folder name" ) parser.add_argument( "--subfolder", type=str, dest="subfolder", default="SCE", help="Subfolder name" ) parser.add_argument( "--dims", nargs="+", action="store", type=int, dest="dims", default=[640, 640, 640], help="Dimensions of the data", ) return parser.parse_args()
if __name__ == "__main__": args = parse_args() print("Starting SCE", flush=True) folder = args.folder os.chdir(folder) cluster_files = glob.glob("*.npy") # -------------------------------------------------- # data subfolder = args.subfolder print(cluster_files) # -------------------------------------------------- # calculate unique number of clusters per run nids_array = find_number_of_clusters(cluster_files) print("nids_array:", nids_array, flush=True) print("There are {} runs".format(len(cluster_files)), flush=True) print("There are {} clusters in total".format(np.sum(nids_array)), flush=True) # -------------------------------------------------- # generate index for multimap_mapping as the loop runs. Avoid declaring a dict beforehand to avoid memory leaks try: # try to create subfolder, if it exists, pass os.mkdir(subfolder) except FileExistsError: pass with open(subfolder + "/multimap_mappings.txt", "w") as f: f.write("") # -------------------------------------------------- # make shape of the data data_dims = np.array(args.dims) # -------------------------------------------------- # loop over data files reading image by image and do pairwise comparisons # all wrapped inside the loop_over_all_clusters function, which uses JAX for fast computation loop_over_all_clusters(cluster_files, nids_array, data_dims, subfolder)