Source code for aweSOM.make_sce_clusters

import argparse
import os
import matplotlib.pyplot as plt
import shutil
import numpy as np
from scipy.signal import savgol_filter


def plot_gsum_values(
    gsum_values: list[float], minimas: list[int] = None, file_path: str = None
):
    """Plot the ranked gsum values, optionally marking the given minima.

    Args:
        gsum_values (list[float]): Gsum values to plot, in ranked order; the
            x-axis is the position in this list.
        minimas (list[int], optional): Indices into ``gsum_values`` to mark
            with blue crosses (with a legend entry). Defaults to None.
        file_path (str, optional): Directory in which ``gsum_values.png`` is
            written. If None, the plot is shown with ``plt.show()`` instead.
            Defaults to None.

    Returns:
        None: The plot is either displayed or saved to disk.
    """
    fig = plt.figure(dpi=300)
    plt.plot(
        list(range(len(gsum_values))),
        gsum_values,
        marker="o",
        c="k",
        markersize=2,
        linewidth=1,
    )
    if minimas is not None:
        plt.scatter(
            minimas,
            [gsum_values[i] for i in minimas],
            c="b",
            marker="x",
            label="Minimas",
        )
        # Only the minima carry a label, so a legend only makes sense here.
        plt.legend()
    plt.title("Sorted gsum values")
    plt.xlabel("Ranked clusters")
    plt.ylabel("Gsum value")
    plt.grid()
    if file_path is None:
        plt.show()
    else:
        plt.savefig(os.path.join(file_path, "gsum_values.png"))
        print("Saved gsum values plot")
        # pyplot keeps every figure alive until it is closed; release it so
        # repeated calls do not accumulate figures in memory.
        plt.close(fig)
def plot_gsum_deriv(
    gsum_deriv: np.ndarray,
    threshold: float,
    minimas: list[int] = None,
    file_path: str = None,
):
    """Plot the ranked gsum derivative with a threshold line and optional minima.

    Args:
        gsum_deriv (np.ndarray): Derivative values to plot; the x-axis is the
            position in this array.
        threshold (float): Level at which a red dashed horizontal line is
            drawn. The y-axis is limited to ``(threshold * 5, 0.0)``, so a
            negative threshold presumably is expected (a positive one would
            invert the axis).
        minimas (list[int], optional): Indices into ``gsum_deriv`` to mark
            with blue crosses. Defaults to None.
        file_path (str, optional): Directory in which ``gsum_deriv.png`` is
            written. If None, the plot is shown with ``plt.show()`` instead.
            Defaults to None.

    Returns:
        None: The plot is either displayed or saved to disk.
    """
    x_range = list(range(len(gsum_deriv)))
    fig = plt.figure(dpi=300)
    plt.plot(x_range, gsum_deriv, marker="o", c="k", markersize=2, linewidth=1)
    if minimas is not None:
        plt.scatter(minimas, [gsum_deriv[i] for i in minimas], c="b", marker="x")
    plt.ylim(threshold * 5, 0.0)
    plt.title("Sorted gsum derivatives")
    plt.xlabel("Ranked clusters")
    plt.ylabel("Gsum derivative")
    plt.grid()
    # x_range[-1] would raise IndexError for an empty array.
    x_max = max(len(x_range) - 1, 0)
    plt.hlines(threshold, 0, x_max, colors="r", linestyles="--")
    if file_path is None:
        plt.show()
    else:
        plt.savefig(os.path.join(file_path, "gsum_deriv.png"))
        print("Saved gsum derivative plot")
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
def _parse_mapping_entry(line: str):
    """Return ``(cluster_id, gsum)`` if *line* is an ``"<int> <float> ..."`` row, else None."""
    tokens = line.split()
    if len(tokens) < 2:
        return None
    try:
        return int(tokens[0]), float(tokens[1])
    except ValueError:
        return None


def get_gsum_values(mapping_file: str):
    """Read a multimap mapping file and rank its entries by gsum value.

    The file consists of header lines (map names, which contain a ``-``)
    each followed by data rows of the form ``"<cluster_id> <gsum>"``.
    Blank lines are ignored.

    Args:
        mapping_file (str): Path to the mapping file.

    Returns:
        tuple:
            - list[float]: All gsum values, sorted in descending order.
            - list[list]: ``[gsum, cluster_id, map_name]`` for every data row,
              in the same (descending gsum, stable) order.

    Raises:
        ValueError: If a data row appears before any header, or a line is
            neither a header nor a parseable data row.
    """
    mapping = dict()
    key_name = None
    with open(mapping_file, "r") as f:
        for line_number, raw_line in enumerate(f, start=1):
            line = raw_line.rstrip("\r\n")
            if not line.strip():
                continue  # tolerate blank lines instead of crashing later
            entry = _parse_mapping_entry(line)
            # A row such as "3 1e-05" also contains "-", so check for a valid
            # data row first; only non-numeric lines with "-" are headers.
            if entry is None and "-" in line:
                key_name = line
                mapping[key_name] = []
            elif entry is None:
                raise ValueError(
                    f"{mapping_file}:{line_number}: cannot parse line {line!r}"
                )
            elif key_name is None:
                raise ValueError(
                    f"{mapping_file}:{line_number}: data row before any map header"
                )
            else:
                mapping[key_name].append(entry)

    map_list = [
        [gsum, cluster_id, key]
        for key, entries in mapping.items()
        for cluster_id, gsum in entries
    ]
    # reverse=True keeps equal gsum values in their original file order.
    map_list.sort(key=lambda item: item[0], reverse=True)
    gsum_values = [item[0] for item in map_list]
    return gsum_values, map_list
def get_sce_cluster_separation(gsum_deriv: np.ndarray, threshold: float):
    """Split the ranked gsum curve into clusters at local minima of its derivative.

    An index ``i`` (excluding the two endpoints) is recorded as a minimum when
    ``gsum_deriv[i]`` is below ``threshold``, strictly smaller than both
    neighbours, and the derivative has risen above ``threshold`` since the
    previously recorded minimum. The first recorded minimum is then dropped.

    Args:
        gsum_deriv (np.ndarray): 1D sequence of derivative values.
        threshold (float): Level the derivative must drop below at a minimum.

    Returns:
        tuple:
            - list[list[int]]: ``[start, end)`` index ranges covering
              ``0..len(gsum_deriv)``, split at each kept minimum. With no kept
              minima this is a single range; it is empty for empty input.
            - list[int]: The kept minimum indices.
    """
    n_values = len(gsum_deriv)
    threshold_crossed = False
    minimas = []
    for i in range(1, n_values - 1):
        value = gsum_deriv[i]
        is_local_min = value < gsum_deriv[i - 1] and value < gsum_deriv[i + 1]
        if threshold_crossed and value < threshold and is_local_min:
            minimas.append(i)
            threshold_crossed = False
        if value > threshold and not threshold_crossed:
            threshold_crossed = True

    # The first minimum is usually part of the first cluster, so drop it
    # (guarded: there may be no minima at all).
    if minimas:
        minimas.pop(0)

    if n_values == 0:
        return [], minimas

    # Consecutive boundaries define the ranges. This also covers the case of a
    # single kept minimum, which previously produced no ranges at all.
    boundaries = [0] + minimas + [n_values]
    cluster_ranges = [
        [start, end] for start, end in zip(boundaries[:-1], boundaries[1:])
    ]
    return cluster_ranges, minimas
def combine_separated_clusters(
    map_list: list, cluster_ranges: list[list[int]], dims: list[int], file_path: str
) -> np.ndarray:
    """Sum the saved binary masks of every instance within each cluster range.

    Args:
        map_list (list): Ranked ``[gsum, cluster_id, map_name]`` entries; each
            entry's mask is loaded from
            ``<file_path>/mask-<map_name>-id<cluster_id>.npy``.
        cluster_ranges (list[list[int]]): ``[start, end)`` index ranges into
            ``map_list``, one per combined cluster.
        dims (list[int] | tuple[int, ...] | int): Shape each loaded mask is
            reshaped to.
        file_path (str): Directory containing the ``.npy`` masks.

    Returns:
        np.ndarray: float32 array of shape ``(len(cluster_ranges), *dims)``
        holding the summed mask of each cluster.
    """
    # Accept a list, tuple, or a single int for the shape.
    dims = (dims,) if isinstance(dims, int) else tuple(dims)

    remapped_clusters = {
        cluster_index: [map_list[j] for j in range(start, end)]
        for cluster_index, (start, end) in enumerate(cluster_ranges)
    }
    print(
        "Length of remapped clusters : ",
        [len(members) for members in remapped_clusters.values()],
        flush=True,
    )

    # Add values of the binary map of each cluster to obtain a new map:
    # read in each binary map within a cluster range, then sum them up.
    all_signals_map = np.empty((len(remapped_clusters), *dims), dtype=np.float32)
    for cluster, members in remapped_clusters.items():
        print("Currently analyzing cluster : ", cluster, flush=True)
        print("Number of instances in cluster : ", len(members), flush=True)
        # Author's note: jax used too much memory here and numba does not
        # support np.load; loading all maps of a cluster at once uses more
        # memory but was ~30% faster than accumulating one at a time.
        this_cluster_signal_map = np.zeros((len(members), *dims), dtype=np.float32)
        for i, instance in enumerate(members):
            if i % 10 == 0:
                print("Instance", i, flush=True)
            mask_file = os.path.join(
                file_path, f"mask-{instance[2]}-id{instance[1]}.npy"
            )
            # Shape passed positionally: the `newshape=` keyword is deprecated
            # since NumPy 2.1.
            this_cluster_signal_map[i] = np.reshape(np.load(mask_file), dims)
        all_signals_map[cluster] = np.sum(this_cluster_signal_map, axis=0)
    return all_signals_map
def make_file_name(n: int, ext: str, width: int = 4) -> str:
    """Build a zero-padded file name from a number and an extension.

    Args:
        n (int): Number to use as the file stem.
        ext (str): File extension, without the leading dot.
        width (int, optional): Minimum number of digits; shorter numbers are
            left-padded with zeros. Defaults to 4.

    Returns:
        str: e.g. ``make_file_name(7, "png") == "0007.png"``.
    """
    # The previous if/elif ladder produced 5 digits ("01234") for n >= 1000,
    # which sorts lexically before "0999"; a format spec pads consistently.
    return f"{n:0{width}d}.{ext}"
def parse_args(argv: "list[str] | None" = None):
    """Parse command-line arguments for the make_sce_clusters.py script.

    Args:
        argv (list[str] | None, optional): Argument strings to parse. If None,
            argparse reads ``sys.argv[1:]``, as before. Defaults to None.

    Returns:
        argparse.Namespace: Parsed options ``file_path``, ``copy_clusters``,
        ``threshold``, ``return_gsum``, ``dims`` and ``save_combined_map``.
    """
    parser = argparse.ArgumentParser(
        description="Use multimap mapping to analyze and segment groups of features"
    )
    parser.add_argument(
        "--file_path",
        type=str,
        dest="file_path",
        default=os.getcwd(),
        help="Multimap mapping file path",
    )
    parser.add_argument(
        "--copy_clusters",
        dest="copy_clusters",
        action="store_true",
        help="Copy the clusters to a new folder",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        dest="threshold",
        default=-0.015,
        help="Threshold for the derivative of the gsum values",
    )
    parser.add_argument(
        "--return_gsum",
        dest="return_gsum",
        action="store_true",
        help="Return the sorted gsum values plot",
    )
    parser.add_argument(
        "--dims",
        nargs="+",
        action="store",
        type=int,
        dest="dims",
        default=[640, 640, 640],
        help="Dimensions of the data",
    )
    parser.add_argument(
        "--save_combined_map",
        dest="save_combined_map",
        action="store_true",
        help="Save the combined map of all clusters",
    )
    return parser.parse_args(argv)
def main():
    """Rank SCE clusters by gsum, split them at derivative minima, and save results."""
    args = parse_args()
    gsum_values, map_list = get_gsum_values(
        os.path.join(args.file_path, "multimap_mappings.txt")
    )
    print("Length of sorted map", len(gsum_values), flush=True)

    # Copy each cluster's image into ranked-clusters/, named by rank.
    if args.copy_clusters:
        ranked_clusters_dir = os.path.join(args.file_path, "ranked-clusters")
        os.makedirs(ranked_clusters_dir, exist_ok=True)
        for rank, (_, cluster_id, map_name) in enumerate(map_list):
            # NOTE(review): this reads "mask-<map>_id<id>.png" ("_id") while
            # combine_separated_clusters reads "mask-<map>-id<id>.npy" ("-id");
            # confirm which naming the mask writer uses for the .png files.
            origin_file_name = os.path.join(
                args.file_path, f"mask-{map_name}_id{cluster_id}.png"
            )
            # make_file_name requires the extension; it was previously omitted,
            # which made --copy_clusters raise TypeError.
            destination_file_name = os.path.join(
                ranked_clusters_dir, make_file_name(rank, "png")
            )
            shutil.copyfile(origin_file_name, destination_file_name)
        print("Done copying files")

    # Apply a Savitzky-Golay filter to smooth the gsum values.
    smooth_fraction = 10
    order = 4
    window_length = len(gsum_values) // smooth_fraction
    print("Applying Savitzky-Golay filter")
    smoothed_map = savgol_filter(gsum_values, window_length, order, deriv=0)
    # Relative derivative of the smoothed curve, used to find the drops.
    gsum_deriv = (
        savgol_filter(smoothed_map, window_length, order, deriv=1) / smoothed_map
    )

    # Find local minima of the derivative that separate the clusters.
    threshold = args.threshold
    cluster_ranges, minimas = get_sce_cluster_separation(gsum_deriv, threshold)
    print("Minimas", minimas, flush=True)
    print("Cluster ranges", cluster_ranges, flush=True)
    print("Number of clusters", len(cluster_ranges), flush=True)

    if args.return_gsum:
        plot_gsum_values(gsum_values, minimas, args.file_path)
        plot_gsum_deriv(gsum_deriv, threshold, minimas, args.file_path)

    # Save the combined map of the separated SCE clusters.
    if args.save_combined_map:
        combined_sce_clusters = combine_separated_clusters(
            map_list, cluster_ranges, args.dims, args.file_path
        )
        output_path = os.path.join(args.file_path, f"sce_clusters_{threshold}.npy")
        np.save(output_path, combined_sce_clusters)
        print(f"Saved new combined clusters as {output_path}")


if __name__ == "__main__":
    main()