# Source code for image_analysis_3D.segmentation_utils.cell_segmentation

"""
# Cell segmentation in 3D
"""

from typing import Union

import cupy
import cupyx
import cupyx.scipy.ndimage
import numpy as np
import scipy
import skimage.filters
import skimage.measure
import skimage.morphology
import skimage.segmentation
from skimage.filters import sobel


def _fill_holes_3d_then_per_slice(binary_cp, structure_3d):
    """Fill enclosed holes on the GPU in 3D, then again in every slice along axis 0 (z).

    The per-slice pass catches holes that are open along z in 3D (z-tunnels)
    and are therefore not enclosed in the full volume. Nonzero input voxels
    are treated as foreground. Returns a boolean CuPy array.
    """
    filled_cp = cupyx.scipy.ndimage.binary_fill_holes(binary_cp, structure=structure_3d)
    for z in range(filled_cp.shape[0]):
        filled_cp[z] = cupyx.scipy.ndimage.binary_fill_holes(filled_cp[z])
    return filled_cp


def fill_holes_in_mask(
    mask: np.ndarray,
    compartment: Union[str, None] = None,
) -> np.ndarray:
    """
    Fill enclosed holes in a 3D segmentation mask.

    A hybrid strategy is used: an enclosed-hole fill over the whole volume
    followed by a 2D enclosed-hole fill of every slice along the first axis (z).

    Parameters
    ----------
    mask : np.ndarray
        3D instance segmented mask image where each cell has a unique integer
        label and background is 0.
    compartment : str, optional
        Compartment type of the mask, compared case-insensitively, by default None.
        ``"cell"`` fills holes label by label and keeps the label IDs and dtype.
        A filled hole only claims background voxels, so a cell nested inside
        another cell's hole is not overwritten. ``"organoid"`` fills holes in
        the binarized mask and relabels connected components with
        ``scipy.ndimage.label``. Any other value returns ``mask`` unchanged.

    Returns
    -------
    np.ndarray
        3D mask with holes filled.

    Raises
    ------
    ValueError
        If ``compartment`` is None, since the hole filling strategy depends on
        the compartment type.
    """
    if compartment is None:
        raise ValueError("Compartment must be specified for hole filling.")

    compartment = compartment.lower()
    if compartment not in ("cell", "organoid"):
        # no filling strategy is defined for other compartments
        return mask

    mask_cp = cupy.asarray(mask)
    # face-connected structuring element for the 3D fill
    structure_3d = cupyx.scipy.ndimage.generate_binary_structure(rank=3, connectivity=1)

    if compartment == "cell":
        filled_mask_cp = cupy.zeros_like(mask_cp)
        background_cp = mask_cp == 0
        # copy the label list to the host once instead of syncing on every element
        for label in cupy.asnumpy(cupy.unique(mask_cp)):
            label = int(label)
            if label == 0:
                continue  # skip background
            own_cp = mask_cp == label
            filled_cp = _fill_holes_3d_then_per_slice(own_cp, structure_3d)
            # The "hole" of this label can contain voxels of another label that
            # sit inside it. Only claim this label's own voxels and enclosed
            # background so that the other cell is not erased.
            filled_mask_cp[filled_cp & (own_cp | background_cp)] = label
        # convert back to numpy array
        return cupy.asnumpy(filled_mask_cp).astype(mask.dtype)

    # organoid: fill the binarized mask, then define labels from connected components
    filled_cp = _fill_holes_3d_then_per_slice(mask_cp, structure_3d)
    mask = cupy.asnumpy(filled_cp).astype(mask.dtype)
    mask, _ = scipy.ndimage.label(mask)
    return mask
def segment_cells_with_3D_watershed(
    cyto_signal: np.ndarray,
    nuclei_mask: np.ndarray,
    thresholded_signal: np.ndarray,
    connectivity: int = 1,
    compactness: float = 0,
) -> np.ndarray:
    """
    Segment cells with a marker-based 3D watershed.

    ``cyto_signal`` is flooded starting from the labels in ``nuclei_mask``.
    ``thresholded_signal`` is passed as the watershed ``mask``, so only voxels
    where it is truthy are labelled. Afterwards the most frequent value in the
    result is set to 0 so that it acts as background. The count includes 0, so
    this is a no-op whenever background already dominates.

    Parameters
    ----------
    cyto_signal : np.ndarray
        3D image to flood, for example an elevation map of the cytoplasm signal.
    nuclei_mask : np.ndarray
        3D label image whose labels seed the watershed.
    thresholded_signal : np.ndarray
        3D array restricting the voxels that may be labelled.
    connectivity : int, optional
        Neighbourhood connectivity forwarded to the watershed. Default is 1.
    compactness : float, optional
        Compactness forwarded to the watershed. Default is 0, meaning no
        compactness constraint on region shape.

    Returns
    -------
    np.ndarray
        3D label image of the segmented cells.
    """
    watershed_options = {
        "connectivity": connectivity,
        "compactness": compactness,
        "mask": thresholded_signal,
    }
    segmented = skimage.segmentation.watershed(
        image=cyto_signal,
        markers=nuclei_mask,
        **watershed_options,
    )

    # Treat the most common value as background. np.unique returns sorted
    # values and argmax picks the first maximum, so ties go to the smallest value.
    values, voxel_counts = np.unique(segmented, return_counts=True)
    background_value = values[voxel_counts.argmax()]
    segmented[segmented == background_value] = 0
    return segmented
def _binarize_smoothed_signal(cyto_signal: np.ndarray, sigma: float) -> np.ndarray:
    """Gaussian-smooth ``cyto_signal`` and binarize it at its Otsu threshold.

    Values below the threshold become 0 and every remaining positive value
    becomes 1. The array returned by ``skimage.filters.gaussian`` is modified
    in place and returned.
    """
    smoothed = skimage.filters.gaussian(cyto_signal, sigma=sigma)
    threshold = skimage.filters.threshold_otsu(smoothed)
    smoothed[smoothed < threshold] = 0
    smoothed[smoothed > 0] = 1
    return smoothed


def _remove_objects_outside_size_range(
    label_mask: np.ndarray, min_size: int, max_size: int
) -> np.ndarray:
    """Zero out, in place, every positive label with fewer than ``min_size`` or more than ``max_size`` voxels.

    Surviving labels keep their IDs. This works in a single pass instead of one
    full-volume comparison per label.
    """
    labels, counts = np.unique(label_mask[label_mask > 0], return_counts=True)
    rejected = labels[(counts < min_size) | (counts > max_size)]
    if rejected.size:
        label_mask[np.isin(label_mask, rejected)] = 0
    return label_mask


def perform_morphology_dependent_segmentation(
    organoid_label: str,
    cyto_signal: np.ndarray,
    nuclei_mask: np.ndarray,
    min_size: int = 1_000,
    max_size: int = 10_000_000,
) -> np.ndarray:
    """
    Perform morphology-dependent cell segmentation.

    The foreground mask and elevation map are chosen from the morphology label.
    A 3D watershed seeded by ``nuclei_mask`` is then run, holes in each cell are
    filled, and objects outside ``[min_size, max_size]`` voxels are removed
    while the IDs of the remaining labels are kept.

    Parameters
    ----------
    organoid_label : str
        Morphology label. ``"failed"`` and ``"blank"`` skip segmentation.
        ``"globular"``, ``"cluster"``, ``"small"``, ``"dissociated"`` and
        ``"elongated"`` select a parameter set.
    cyto_signal : np.ndarray
        3D numpy array representing the cytoplasm signal.
    nuclei_mask : np.ndarray
        3D numpy array representing the nuclei mask, used as watershed markers.
    min_size : int, optional
        Objects with fewer voxels than this are removed. Default is 1,000 voxels.
    max_size : int, optional
        Objects with more voxels than this are removed. Default is 10,000,000 voxels.

    Returns
    -------
    np.ndarray
        3D numpy array representing the segmented cell mask. For
        ``"failed"``/``"blank"`` this is an all-zero int32 array shaped like
        ``cyto_signal``.

    Raises
    ------
    ValueError
        If ``organoid_label`` is not one of the labels above. This is checked
        before any filtering is done.
    """
    if organoid_label in {"failed", "blank"}:
        print("Failed/blank morphology selected, skipping cell segmentation")
        return np.zeros_like(cyto_signal, dtype=np.int32)

    # Per-morphology settings:
    # threshold_sigma  - gaussian sigma before Otsu thresholding
    # dilation_radius  - ball radius used to dilate the foreground (None = no dilation)
    # elevation_sigma  - gaussian sigma applied to the low-pass signal before sobel (None = none)
    if organoid_label in {"globular", "cluster"}:
        threshold_sigma, dilation_radius, elevation_sigma = 2.5, None, None
        connectivity, compactness = 1, 1
    elif organoid_label in {"small", "dissociated"}:
        threshold_sigma, dilation_radius, elevation_sigma = 3, 1, 1
        connectivity, compactness = 1, 0
    elif organoid_label in {"elongated"}:
        threshold_sigma, dilation_radius, elevation_sigma = 4, 10, 1
        # NOTE(review): skimage builds the watershed neighbourhood with
        # scipy's generate_binary_structure, which appears to treat values
        # below 1 as 1 - confirm whether 1 was the intended value here.
        connectivity, compactness = 0, 0
    else:
        raise ValueError(f"Unknown morphology label: {organoid_label}")

    # foreground mask that restricts the watershed
    foreground = _binarize_smoothed_signal(cyto_signal, sigma=threshold_sigma)
    if dilation_radius is not None:
        foreground = skimage.morphology.dilation(
            foreground,
            skimage.morphology.ball(dilation_radius),
        )

    # low-frequency signal shared by every morphology type
    low_pass_signal = skimage.filters.butterworth(
        cyto_signal,
        cutoff_frequency_ratio=0.08,
        order=2,
        high_pass=False,
        squared_butterworth=False,
    )
    if elevation_sigma is not None:
        low_pass_signal = skimage.filters.gaussian(low_pass_signal, sigma=elevation_sigma)
    elevation_map = sobel(low_pass_signal)

    cell_mask = segment_cells_with_3D_watershed(
        cyto_signal=elevation_map,
        nuclei_mask=nuclei_mask,
        connectivity=connectivity,
        compactness=compactness,
        thresholded_signal=foreground,
    )
    cell_mask = fill_holes_in_mask(cell_mask, compartment="cell")

    # remove objects outside the size range while preserving label IDs
    return _remove_objects_outside_size_range(cell_mask, min_size, max_size)