from __future__ import annotations
from typing import List, Tuple
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import skimage
import tqdm
from image_analysis_3D.segmentation_utils.segmentation_decoupling import (
euclidian_2D_distance,
extract_unique_masks,
get_combinations_of_indices,
merge_sets_df,
reassemble_each_mask,
)
from skimage.filters import sobel
# ----------------------------------------------------------------------
# convert to 2.5 D image stack
# ----------------------------------------------------------------------
[docs]
def sliding_window_two_point_five_D(
image_stack: np.ndarray, window_size: int
) -> np.ndarray:
"""Create 2.5D max-projection stack using a sliding window.
Parameters
----------
image_stack : np.ndarray
Input 3D image stack (Z, Y, X).
window_size : int
Number of slices to project per window.
Returns
-------
np.ndarray
2.5D stack of max projections.
"""
image_stack_2_5D = np.empty(
(0, image_stack.shape[1], image_stack.shape[2]), dtype=image_stack.dtype
)
for image_index in range(image_stack.shape[0]):
image_stack_window = image_stack[image_index : image_index + window_size]
if not image_stack_window.shape[0] == window_size:
break
# max project the image stack
image_stack_2_5D = np.array(
np.append(
image_stack_2_5D,
np.max(image_stack_window, axis=0)[np.newaxis, :, :],
axis=0,
)
)
return image_stack_2_5D
[docs]
def reverse_sliding_window_max_projection(
output_dict: dict,
window_size: int,
original_z_slice_count: int,
) -> dict:
"""Reconstruct per-slice masks from sliding-window projections.
Parameters
----------
output_dict : dict
Output dictionary with projected labels.
window_size : int
Sliding window size used during projection.
original_z_slice_count : int
Number of slices in the original stack.
Returns
-------
dict
Mapping of slice index to list of reconstructed masks.
"""
# reverse the sliding window
# reverse sliding window max projection
full_mask_z_stack = []
reconstruction_dict = {index: [] for index in range(original_z_slice_count)}
# loop through the sliding window max projected masks and decouple them
for z_stack_mask_index in range(len(output_dict["labels"])):
z_stack_decouple = []
# make n copies of the mask for sliding window decoupling
# where n is the size of the sliding window
[
z_stack_decouple.append(output_dict["labels"][z_stack_mask_index])
for _ in range(window_size)
]
for z_window_index, z_stack_mask in enumerate(z_stack_decouple):
# append the masks to the reconstruction_dict
if not (z_stack_mask_index + z_window_index) >= original_z_slice_count:
reconstruction_dict[z_stack_mask_index + z_window_index].append(
z_stack_mask
)
return reconstruction_dict
# ----------------------------------------------------------------------
# Butterworth filtering function
# ----------------------------------------------------------------------
[docs]
def butterworth_grid_optimization(
img: np.ndarray,
return_plot: bool = False,
) -> None:
"""Sweep Butterworth parameters and optionally plot results.
Parameters
----------
img : np.ndarray
Image stack used for optimization.
return_plot : bool, optional
Whether to display a parameter grid plot, by default False.
Returns
-------
None
Only displays plots when requested.
"""
# get the median most image from the cyto image stack
# this is the image that will be used for the butterworth filter optimization
# get the median image from the cyto image stack
middle_index = int(img.shape[0] / 2)
img_to_optimize = img[middle_index]
optimization_steps = 5
# optimize the butterworth filter for the cyto image
search_space_cutoff_freq = np.linspace(0.01, 0.5, optimization_steps)
search_space_order = np.linspace(1, 10, optimization_steps)
# create a list of optimzation parameter pairs
optimization_parameter_pairs = []
for cutoff_freq_option in search_space_cutoff_freq:
for order_option in search_space_order:
optimization_parameter_pairs.append((cutoff_freq_option, order_option))
optimized_images = []
# loop through the optimization pairs to find the best pararmeters
for cutoff_freq_option, order_option in tqdm.tqdm(optimization_parameter_pairs):
optimized_images.append(
skimage.filters.butterworth(
img_to_optimize,
cutoff_frequency_ratio=cutoff_freq_option,
high_pass=False,
order=order_option,
squared_butterworth=True,
)
)
if return_plot:
# visualize the optimized images in a grid
fig, ax = plt.subplots(optimization_steps, optimization_steps, figsize=(20, 20))
for i in range(optimization_steps):
for j in range(optimization_steps):
ax[i, j].imshow(optimized_images[i * optimization_steps + j])
ax[i, j].axis("off")
# add the cutoff frequency and order to the plot
ax[i, j].set_title(
f"Freq: {search_space_cutoff_freq[i]:.2f}, Order: {search_space_order[j]:.2f}"
)
plt.show()
[docs]
def apply_butterworth_filter(
img: np.ndarray,
cutoff_frequency_ratio: float = 0.05,
order: int = 1,
high_pass: bool = False,
squared_butterworth: bool = True,
) -> np.ndarray:
"""Apply a Butterworth filter and Gaussian smoothing.
Parameters
----------
img : np.ndarray
Input image stack to filter.
cutoff_frequency_ratio : float, optional
Cutoff frequency ratio, by default 0.05.
order : int, optional
Butterworth filter order, by default 1.
high_pass : bool, optional
Whether to use a high-pass filter, by default False.
squared_butterworth : bool, optional
Use squared Butterworth response, by default True.
Returns
-------
np.ndarray
Filtered image stack.
"""
# Use butterworth FFT filter to remove high frequency noise :)
for i in range(img.shape[0]):
img[i, :, :] = skimage.filters.butterworth(
img[i, :, :],
cutoff_frequency_ratio=cutoff_frequency_ratio,
high_pass=high_pass,
order=order,
squared_butterworth=squared_butterworth,
)
# add a guassian blur to the image
img = skimage.filters.gaussian(img, sigma=1)
return img
# ----------------------------------------------------------------------
# decoupling segmented masks
# ----------------------------------------------------------------------
[docs]
def decouple_masks(
reconstruction_dict: dict,
original_img_shape: np.ndarray,
distance_threshold: int,
verbose: bool = False,
) -> dict:
"""Decouple projected masks into per-slice masks.
Parameters
----------
reconstruction_dict : dict
Mapping of slice index to projected masks.
original_img_shape : np.ndarray
Shape of the original image stack.
distance_threshold : int
Distance threshold for mask merging.
verbose : bool, optional
Whether to print warnings, by default False.
Returns
-------
dict
Mapping of slice index to reassembled masks.
"""
masks_dict = {}
for zslice, arrays in tqdm.tqdm(
enumerate(reconstruction_dict), total=len(reconstruction_dict)
):
df = extract_unique_masks(reconstruction_dict[zslice])
merged_df = get_combinations_of_indices(
df, distance_threshold=distance_threshold
)
# combine dfs for each window index
merged_df = merge_sets_df(merged_df)
if not merged_df.empty:
merged_df.loc[:, "slice"] = zslice
reassembled_masks = reassemble_each_mask(
merged_df, original_img_shape=original_img_shape
)
masks_dict[zslice] = reassembled_masks
else:
if verbose:
print(f"Warning: merged_df is empty for zslice {zslice}")
masks_dict[zslice] = reconstruction_dict[zslice][0]
return masks_dict
# ------------------------------------------------------
# reconstruct full 3D masks from decoupled masks
# ------------------------------------------------------
[docs]
def generate_coordinates_for_reconstruction(image: np.ndarray) -> pd.DataFrame:
"""Generate centroid and bounding-box coordinates for reconstruction.
Parameters
----------
image : np.ndarray
Labeled mask stack (Z, Y, X).
Returns
-------
pd.DataFrame
DataFrame of labels, centroids, and bounding boxes.
"""
cordinates = {
"original_label": [],
"slice": [],
"centroid-0": [],
"centroid-1": [],
"bbox-0": [],
"bbox-1": [],
"bbox-2": [],
"bbox-3": [],
}
for image_slice in range(image.shape[0]):
props = skimage.measure.regionprops_table(
image[image_slice, :, :], properties=["label", "centroid", "bbox"]
)
label, centroid1, centroid2, bbox0, bbox1, bbox2, bbox3 = (
props["label"],
props["centroid-0"],
props["centroid-1"],
props["bbox-0"],
props["bbox-1"],
props["bbox-2"],
props["bbox-3"],
)
if len(label) > 0:
for i in range(len(label)):
cordinates["original_label"].append(label[i])
cordinates["slice"].append(image_slice)
cordinates["centroid-0"].append(centroid1[i])
cordinates["centroid-1"].append(centroid2[i])
cordinates["bbox-0"].append(bbox0[i])
cordinates["bbox-1"].append(bbox1[i])
cordinates["bbox-2"].append(bbox2[i])
cordinates["bbox-3"].append(bbox3[i])
coordinates_df = pd.DataFrame(cordinates)
coordinates_df["unique_id"] = coordinates_df.index
return coordinates_df
[docs]
def generate_distance_pairs(
coordinates_df: pd.DataFrame, x_y_vector_radius_max_constraint: int
):
"""Create distance pairs for centroid matching across slices.
Parameters
----------
coordinates_df : pd.DataFrame
DataFrame containing centroid coordinates.
x_y_vector_radius_max_constraint : int
Maximum centroid distance to include.
Returns
-------
pd.DataFrame
Pairwise centroid distances within the constraint.
"""
# generate distance pairs for each slice
distance_pairs = {
"slice1": [],
"slice2": [],
"index1": [],
"index2": [],
"distance": [],
"coordinates1": [],
"coordinates2": [],
"pass": [],
"original_label1": [],
"original_label2": [],
}
distance_pairs_list = [
{
"slice1": coordinates_df.loc[i, "slice"],
"slice2": coordinates_df.loc[j, "slice"],
"index1": i,
"index2": j,
"distance": euclidian_2D_distance(
coordinates_df.loc[i, ["centroid-0", "centroid-1"]].values,
coordinates_df.loc[j, ["centroid-0", "centroid-1"]].values,
),
"coordinates1": tuple(
coordinates_df.loc[i, ["centroid-0", "centroid-1"]].values
),
"coordinates2": tuple(
coordinates_df.loc[j, ["centroid-0", "centroid-1"]].values
),
"pass": True,
"original_label1": coordinates_df.loc[i, "original_label"],
"original_label2": coordinates_df.loc[j, "original_label"],
}
for i in range(coordinates_df.shape[0])
for j in range(coordinates_df.shape[0])
if i != j
and euclidian_2D_distance(
coordinates_df.loc[i, ["centroid-0", "centroid-1"]].values,
coordinates_df.loc[j, ["centroid-0", "centroid-1"]].values,
)
< x_y_vector_radius_max_constraint
]
# Convert to DataFrame (if needed)
df = pd.DataFrame(distance_pairs_list)
if not df.empty:
df["indexes"] = df["index1"].astype(str) + "-" + df["index2"].astype(str)
df = df[df["pass"] == True]
df["index_comparison"] = (
df["index1"].astype(str) + "," + df["index2"].astype(str)
)
df.head()
return df
def calculate_bbox_area(bbox: Tuple[int, int, int, int]) -> int:
"""
Calculate the area of a bounding box.
Parameters
----------
bbox : Tuple[int, int, int, int]
The bounding box coordinates in the format (x_min, y_min, x_max, y_max).
Returns
-------
int
The area of the bounding box.
"""
return max(0, bbox[2] - bbox[0]) * max(0, bbox[3] - bbox[1])
def calculate_overlap(
bbox1: tuple[int, int, int, int], bbox2: tuple[int, int, int, int]
) -> float:
"""
Calculate the percentage overlap between two bounding boxes.
Parameters
----------
bbox1 : Tuple[int, int, int, int]
The first bounding box (x_min, y_min, x_max, y_max).
bbox2 : Tuple[int, int, int, int]
The second bounding box (x_min, y_min, x_max, y_max).
Returns
-------
float
The percentage overlap of the smaller bounding box with the larger one.
"""
# Calculate intersection coordinates
x_min = max(bbox1[0], bbox2[0])
y_min = max(bbox1[1], bbox2[1])
x_max = min(bbox1[2], bbox2[2])
y_max = min(bbox1[3], bbox2[3])
# Calculate intersection area
overlap_width = max(0, x_max - x_min)
overlap_height = max(0, y_max - y_min)
overlap_area = overlap_width * overlap_height
# Calculate areas of both bounding boxes
area1 = calculate_bbox_area(bbox1)
area2 = calculate_bbox_area(bbox2)
# Return the percentage overlap relative to the smaller bounding box
smaller_area = min(area1, area2)
return overlap_area / smaller_area if smaller_area > 0 else 0.0
[docs]
def calculate_mask_iou(mask1: np.ndarray, mask2: np.ndarray) -> bool:
"""
Calculate the Intersection over Union (IoU) between two binary masks.
Parameters
----------
mask1 : np.ndarray
The first binary mask.
mask2 : np.ndarray
The second binary mask.
Returns
-------
bool
True if the IoU is greater than 0.5, False otherwise.
"""
intersection = np.logical_and(mask1, mask2)
union = np.logical_or(mask1, mask2)
if np.sum(union) == 0:
return False
iou = np.sum(intersection) / np.sum(union)
return iou
[docs]
def graph_creation(df: pd.DataFrame) -> nx.Graph:
"""Build a graph connecting centroid pairs.
Parameters
----------
df : pd.DataFrame
DataFrame of centroid pairs and distances.
Returns
-------
networkx.Graph
Graph with centroid nodes and distance edges.
"""
# create a graph where each node is a unique centroid and each edge is a distance between centroids
# edges between nodes with the same slice are not allowed
# edge weight is the distance between the nodes (euclidian distance)
G = nx.Graph()
for row in df.iterrows():
G.add_node(
row[1]["index1"], slice=row[1]["slice1"], coordinates=row[1]["coordinates1"]
)
G.add_node(
row[1]["index2"], slice=row[1]["slice2"], coordinates=row[1]["coordinates2"]
)
G.add_edge(
row[1]["index1"],
row[1]["index2"],
weight=row[1]["distance"],
original_label1=row[1]["original_label1"],
original_label2=row[1]["original_label2"],
)
# plot the graph with each slice being on a different row
pos = nx.spring_layout(G)
edge_labels = nx.get_edge_attributes(G, "weight")
return G
[docs]
def solve_graph(G: nx.Graph) -> list[list]:
"""Solve for longest shortest paths in a graph.
Parameters
----------
G : networkx.Graph
Graph of centroid connections.
Returns
-------
list[list[int]]
Longest paths discovered in the graph.
"""
# solve the the shortest path problem
# find the longest paths in the graph with the smallest edge weights
# this will find the longest paths between centroids closest to each other
# the longest path is the path with the most edges
longest_paths = []
for path in nx.all_pairs_shortest_path(G, cutoff=10):
longest_path = []
for key in path[1].keys():
if len(path[1][key]) > len(longest_path):
longest_path = path[1][key]
longest_paths.append(longest_path)
return longest_paths
[docs]
def merge_sets(list_of_sets: list) -> list:
"""Merge overlapping sets of node indices.
Parameters
----------
list_of_sets : list
List of set objects to merge.
Returns
-------
list
Updated list of merged sets.
"""
for i, set1 in enumerate(list_of_sets):
for j, set2 in enumerate(list_of_sets):
if i != j and len(set1.intersection(set2)) > 0:
set1.update(set2)
return list_of_sets
[docs]
def collapse_labels(df: pd.DataFrame, longest_paths: list) -> pd.DataFrame:
"""Collapse labels using graph paths.
Parameters
----------
df : pd.DataFrame
DataFrame containing unique IDs.
longest_paths : list
List of paths from graph solution.
Returns
-------
pd.DataFrame
Updated DataFrame with collapsed labels.
"""
list_of_sets = [set(x) for x in longest_paths]
merged_sets = merge_sets(list_of_sets)
merged_sets_dict = {}
for i in range(len(list_of_sets)):
merged_sets_dict[i] = list_of_sets[i]
for row in df.iterrows():
for num_set in merged_sets_dict:
if int(row[1]["unique_id"]) in merged_sets_dict[num_set]:
df.at[row[0], "label"] = num_set
# drop nan
df = df.dropna()
return df
[docs]
def reassign_labels(
image: np.ndarray,
df: pd.DataFrame,
):
"""Reassign labels in a mask based on a mapping DataFrame.
Parameters
----------
image : np.ndarray
Mask stack to relabel.
df : pd.DataFrame
DataFrame containing label mappings by slice.
Returns
-------
np.ndarray
Relabeled mask stack.
"""
new_mask_image = np.zeros_like(image)
# mask label reassignment
for image_slice in range(image.shape[0]):
mask = image[image_slice, :, :]
tmp_df = df[df["slice"] == image_slice]
if tmp_df.empty:
continue
# check if label is present or if reassignment is needed
if "label" not in tmp_df.columns:
continue
for i in range(tmp_df.shape[0]):
mask[mask == tmp_df.iloc[i]["original_label"]] = tmp_df.iloc[i]["label"]
new_mask_image[image_slice, :, :] = mask
return new_mask_image
# ----------------------------------------------------------------------
# post hoc refinements
# ----------------------------------------------------------------------
[docs]
def calculate_bbox_area(bbox: Tuple[int, int, int, int]) -> int:
"""
Calculate the area of a bounding box.
Parameters
----------
bbox : Tuple[int, int, int, int]
The bounding box coordinates in the format (x_min, y_min, x_max, y_max).
Returns
-------
int
The area of the bounding box.
"""
return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
[docs]
def calculate_overlap(
bbox1: Tuple[int, int, int, int], bbox2: Tuple[int, int, int, int]
) -> float:
"""Calculate overlap percentage between two bounding boxes.
Parameters
----------
bbox1 : Tuple[int, int, int, int]
First bounding box.
bbox2 : Tuple[int, int, int, int]
Second bounding box.
Returns
-------
float
Overlap percentage relative to the smaller box.
"""
# calculate the % overlap of the second bbox with the first bbox
if calculate_bbox_area(bbox1) == 0 or calculate_bbox_area(bbox2) == 0:
return 0.0
if calculate_bbox_area(bbox1) >= calculate_bbox_area(bbox2):
x_min = max(bbox1[0], bbox2[0])
y_min = max(bbox1[1], bbox2[1])
x_max = min(bbox1[2], bbox2[2])
y_max = min(bbox1[3], bbox2[3])
overlap_width = max(0, x_max - x_min)
overlap_height = max(0, y_max - y_min)
overlap_area = overlap_width * overlap_height
bbox1_area = calculate_bbox_area(bbox1)
bbox2_area = calculate_bbox_area(bbox2)
overlap_percentage = overlap_area / bbox2_area if bbox2_area > 0 else 0
return overlap_percentage
elif calculate_bbox_area(bbox1) < calculate_bbox_area(bbox2):
x_min = max(bbox1[0], bbox2[0])
y_min = max(bbox1[1], bbox2[1])
x_max = min(bbox1[2], bbox2[2])
y_max = min(bbox1[3], bbox2[3])
overlap_width = max(0, x_max - x_min)
overlap_height = max(0, y_max - y_min)
overlap_area = overlap_width * overlap_height
bbox1_area = calculate_bbox_area(bbox1)
bbox2_area = calculate_bbox_area(bbox2)
overlap_percentage = overlap_area / bbox1_area if bbox1_area > 0 else 0
return overlap_percentage
else:
print("Error: Bboxes are the same size")
[docs]
def check_for_all_same_labels(
object_information_df: pd.DataFrame,
) -> bool:
"""
Check if all labels in the object information DataFrame are the same.
Parameters
----------
object_information_df : pd.DataFrame
The DataFrame containing object information with 'label' column.
Returns
-------
bool
True if all labels are the same, False otherwise.
"""
return object_information_df["label"].nunique() == 1
[docs]
def missing_slice_check(
object_information_df: pd.DataFrame,
window_min: int = 0,
window_max: int = 2,
interpolated_rows_to_add: List[int] = [],
) -> List[pd.DataFrame]:
"""
Check for missing slices in the object information DataFrame and add interpolated rows if necessary.
Parameters
----------
object_information_df : pd.DataFrame
The DataFrame containing object information with 'z' and 'label' columns.
window_min : int, optional
The minimum window size for checking missing slices, by default 0
window_max : int, optional
The maximum window size for checking missing slices, by default 2
interpolated_rows_to_add : List[int], optional
A list to store rows to be added for interpolation, by default []
Returns
-------
List[pd.DataFrame]
A list of DataFrames containing rows to be added for interpolation.
"""
max_z = object_information_df["z"].max()
min_z = object_information_df["z"].min()
if max_z - min_z > 1:
if len(object_information_df) < 3:
# get the first row
row = object_information_df.iloc[0]
new_row = {
"added_z": row["z"],
"added_new_label": row["label"],
"zslice_to_copy": row["z"],
}
# interpolate the labels to the middle most slice
# get the middle slice
middle_slice = int((max_z + min_z) / 2)
# insert one slice
z_zlice_to_copy = row["z"]
new_row = {
# 'index': object_information_df['index'].values[0],
# 'index': object_max_slice_label,
"added_z": middle_slice,
"added_new_label": row["label"],
"zslice_to_copy": z_zlice_to_copy,
}
interpolated_rows_to_add.append(pd.DataFrame(new_row, index=[0]))
return interpolated_rows_to_add
[docs]
def add_min_max_boundry_slices(
object_information_df: pd.DataFrame,
global_min_z: int,
global_max_z: int,
interpolated_rows_to_add: List[pd.DataFrame] = [],
) -> List[pd.DataFrame]:
"""
Add slices to the object information DataFrame that are one slice away from the global min and max z slices.
Parameters
----------
object_information_df : pd.DataFrame
The DataFrame containing object information with 'z' and 'label' columns.
global_min_z : int
The global minimum z slice.
global_max_z : int
The global maximum z slice.
interpolated_rows_to_add : List[pd.DataFrame], optional
A list to store rows to be added for interpolation, by default []
Returns
-------
List[pd.DataFrame]
A list of DataFrames containing rows to be added for interpolation at the min and max z slices.
"""
# find labels that are 1 slice away from the min or max and extend the label
for i, row in object_information_df.iterrows():
# check if the z slice is one away from the min or max (global min and max)
if row["z"] == global_max_z - 1:
new_row = {
"added_z": global_max_z,
"added_new_label": row["label"],
"zslice_to_copy": row["z"],
}
interpolated_rows_to_add.append(pd.DataFrame(new_row, index=[0]))
elif row["z"] == global_min_z + 1:
new_row = {
"added_z": global_min_z,
"added_new_label": row["label"],
"zslice_to_copy": row["z"],
}
interpolated_rows_to_add.append(pd.DataFrame(new_row, index=[0]))
return interpolated_rows_to_add
[docs]
def add_masks_where_missing(
new_mask_image: np.ndarray,
interpolated_rows_to_add_df: pd.DataFrame,
) -> np.ndarray:
"""
Add masks to the new mask image where the slices are missing based on the interpolated rows.
Parameters
----------
new_mask_image : np.ndarray
The new mask image to which the slices will be added.
interpolated_rows_to_add_df : pd.DataFrame
The DataFrame containing the rows to be added for interpolation, with columns 'added_z', 'added_new_label', and 'zslice_to_copy'.
Returns
-------
np.ndarray
The new mask image with the added slices.
"""
for image_slice in interpolated_rows_to_add_df["added_z"].unique():
# get the rows that correspond to the slice
tmp_df = interpolated_rows_to_add_df[
interpolated_rows_to_add_df["added_z"] == image_slice
]
if tmp_df.shape[0] == 0:
continue
for i, row in tmp_df.iterrows():
# get the z slice to copy mask
new_slice = new_mask_image[row["zslice_to_copy"].astype(int), :, :].copy()
new_slice[new_slice != row["added_new_label"]] = 0
old_slice = new_mask_image[row["added_z"].astype(int), :, :].copy()
max_projected_slice = np.maximum(old_slice, new_slice)
new_mask_image[row["added_z"].astype(int), :, :] = max_projected_slice
return new_mask_image
[docs]
def reorder_organoid_labels(
label_image: np.ndarray,
) -> np.ndarray:
"""
Reorder the labels in the label image to ensure they are sequential starting from 1.
Parameters
----------
label_image : np.ndarray
The label image where labels need to be reordered.
Returns
-------
np.ndarray
The label image with reordered labels.
"""
unique_labels = np.unique(label_image)
# remove the background label (0)
unique_labels = unique_labels[unique_labels != 0]
# exit early if there are no labels (only background)
if len(unique_labels) == 0:
return label_image
# create a mapping from old label to new label
label_mapping = {
old_label: new_label
for new_label, old_label in enumerate(unique_labels, start=1)
}
label_image_corrected = np.copy(label_image)
for old_label, new_label in label_mapping.items():
label_image_corrected[label_image == old_label] = new_label
return label_image_corrected
[docs]
def run_post_hoc_refinement(
mask_image: List[int],
sliding_window_context: int,
) -> np.ndarray:
"""Refine labels by interpolating across missing slices.
Parameters
----------
mask_image : List[int]
3D labeled mask stack.
sliding_window_context : int
Number of slices to consider in the sliding context.
Returns
-------
np.ndarray
Refined mask stack.
"""
new_mask_image = mask_image.copy()
global_max_z = mask_image.shape[0] # number of z slices
global_min_z = 0
# expand the z slices into a list of slices between the min and max z slices
z_slices = [x for x in range(global_min_z, global_max_z)]
for z in z_slices[: -(sliding_window_context - 1)]:
interpolated_rows_to_add = []
final_dict = {
"index1": [],
"index2": [],
"z1": [],
"z2": [],
"distance": [],
"label1": [],
"label2": [],
}
list_of_cell_masks = []
for z_slice in range(0, new_mask_image.shape[0] - 1):
compartment_df = pd.DataFrame.from_dict(
skimage.measure.regionprops_table(
new_mask_image[z, :, :],
properties=["centroid", "bbox"],
)
)
compartment_df["z"] = z_slice
list_of_cell_masks.append(compartment_df)
compartment_df = pd.concat(list_of_cell_masks)
# get the pixel value of the organoid mask at each x,y,z coordinate
compartment_df["label"] = new_mask_image[
compartment_df["z"].astype(int),
compartment_df["centroid-0"].astype(int),
compartment_df["centroid-1"].astype(int),
]
compartment_df.reset_index(drop=True, inplace=True)
compartment_df["new_label"] = compartment_df["label"]
# drop all labels that are 0
compartment_df = compartment_df[compartment_df["label"] != 0]
# Get the temporary sliding window
tmp_window_df = compartment_df[
(compartment_df["z"] >= z)
& (compartment_df["z"] < z + sliding_window_context)
]
if tmp_window_df["z"].nunique() < sliding_window_context:
continue
for i, row1 in tmp_window_df.iterrows():
for j, row2 in tmp_window_df.iterrows():
if i != j: # Ensure you're not comparing the same row
if row1["z"] != row2["z"]:
# get the first bbox
distance = euclidian_2D_distance(
(row1["centroid-0"], row1["centroid-1"]),
(row2["centroid-0"], row2["centroid-1"]),
)
if distance < 20:
final_dict["index1"].append(i)
final_dict["index2"].append(j)
final_dict["z1"].append(row1["z"])
final_dict["z2"].append(row2["z"])
final_dict["distance"].append(distance)
final_dict["label1"].append(row1["label"])
final_dict["label2"].append(row2["label"])
final_df = pd.DataFrame.from_dict(final_dict)
final_df["index_set"] = final_df.apply(
lambda row: frozenset([row["index1"], row["index2"]]), axis=1
)
final_df["index_set"] = final_df["index_set"].apply(lambda x: tuple(sorted(x)))
list_of_sets = final_df["index_set"].tolist()
list_of_sets = [set(s) for s in list_of_sets]
merged_sets = merge_sets(list_of_sets)
# drop the duplicates
merged_sets = list({frozenset(s): s for s in merged_sets}.values())
# from final_df generate the z-ordered cases
for object_set in merged_sets:
# find rows that contain integers that are in the object_set
rows_that_contain_object_set = final_df[
final_df["index_set"].apply(lambda x: set(x).issubset(object_set))
]
# get the index, label, and z pair
dict_of_object_information = {"index": [], "label": [], "z": []}
for i, row in rows_that_contain_object_set.iterrows():
dict_of_object_information["index"].append(row["index1"])
dict_of_object_information["label"].append(row["label1"])
dict_of_object_information["z"].append(row["z1"])
dict_of_object_information["index"].append(row["index2"])
dict_of_object_information["label"].append(row["label2"])
dict_of_object_information["z"].append(row["z2"])
object_information_df = pd.DataFrame.from_dict(dict_of_object_information)
object_information_df.drop_duplicates(
subset=["index", "label", "z"], inplace=True
)
object_information_df.sort_values(by=["index", "z"], inplace=True)
if check_for_all_same_labels(object_information_df):
# if all labels are the same, skip this object
continue
interpolated_rows_to_add = missing_slice_check(
object_information_df, interpolated_rows_to_add=interpolated_rows_to_add
)
interpolated_rows_to_add = add_min_max_boundry_slices(
object_information_df,
global_min_z=global_min_z,
global_max_z=global_max_z,
interpolated_rows_to_add=interpolated_rows_to_add,
)
if len(interpolated_rows_to_add) == 0:
if z == z_slices[-1]:
# tifffile.imwrite(mask_output_path, new_mask_image)
return new_mask_image
else:
continue
interpolated_rows_to_add_df = pd.concat(interpolated_rows_to_add, axis=0)
new_mask_image = new_mask_image.copy()
new_mask_image = add_masks_where_missing(
new_mask_image=new_mask_image,
interpolated_rows_to_add_df=interpolated_rows_to_add_df,
)
return new_mask_image
# ----------------------------------------------------------------------
# Segment the cells with 3D watershed
# ----------------------------------------------------------------------
[docs]
def segment_cells_with_3D_watershed(
cyto_signal: np.ndarray,
nuclei_mask: np.ndarray,
) -> np.ndarray:
"""Segment cells using seeded 3D watershed.
Parameters
----------
cyto_signal : np.ndarray
Cytoplasm signal image stack.
nuclei_mask : np.ndarray
Nuclei mask used as watershed seeds.
Returns
-------
np.ndarray
Cell segmentation mask.
"""
# gaussian filter to smooth the image
cell_signal_image = skimage.filters.gaussian(cyto_signal, sigma=1.0)
# scale the pixels to max 255
nuclei_mask = (nuclei_mask / nuclei_mask.max() * 255).astype(np.uint8)
# generate the elevation map using the Sobel filter
elevation_map = sobel(cell_signal_image)
# set up seeded watersheding where the nuclei masks are used as seeds
# note: the cytoplasm is used as the signal for this.
labels = skimage.segmentation.watershed(
image=elevation_map,
markers=nuclei_mask,
)
# change the largest label (by area) to 0
# cleans up the output and sets the background properly
unique, counts = np.unique(labels, return_counts=True)
largest_label = unique[np.argmax(counts)]
labels[labels == largest_label] = 0
cell_mask = labels.copy()
cell_mask = run_post_hoc_refinement(
mask_image=cell_mask,
sliding_window_context=3,
)
return labels
# ----------------------------------------------------------------------
# post hoc reassignments
# ----------------------------------------------------------------------
[docs]
def remove_edge_cases(
mask: np.ndarray,
border: int = 10,
) -> np.ndarray:
"""
Remove masks that are image edge cases
In this case - the edge literally means the edge of the image
This is useful to remove masks that are not fully contained within the image
Parameters
----------
mask : np.ndarray
The mask to process, should be a 3D numpy array
border : int, optional
The number of pixels in width to create border to scan for edge cased, by default 10
Returns
-------
np.ndarray
The mask with edge cases removed
"""
edge_pixels = np.concatenate(
[
# all of z, last n rows (y), all columns (x) - bottom edge
mask[:, -border:, :].flatten(),
# all of z, first n rows (y), all columns (x) - top edge
mask[:, 0:border, :].flatten(),
# all of z, all rows (y), first n columns (x) - left edge
mask[:, :, 0:border:].flatten(),
# all of z, all rows (y), last n columns (x) - right edge
mask[:, :, -border:].flatten(),
# each are the edges stacked for the whole volume -> no need to specify every z slice or 3D edge
]
)
# get unique edge pixel values
edge_pixels = np.unique(edge_pixels[edge_pixels > 0])
for edge_pixel_case in edge_pixels:
# make the edge cases equal to zero
mask[mask == edge_pixel_case] = 0
# return the mask with edge cases removed
return mask
[docs]
def centroid_within_bbox_detection(
centroid: tuple,
bbox: tuple,
) -> bool:
"""
Check if the centroid is within the bbox
Parameters
----------
centroid : tuple
Centroid of the object in the order of (z, y, x)
Order of the centroid is important
bbox : tuple
Where the bbox is in the order of (z_min, y_min, x_min, z_max, y_max, x_max)
Order of the bbox is important
Returns
-------
bool
True if the centroid is within the bbox, False otherwise
"""
z_min, y_min, x_min, z_max, y_max, x_max = bbox
z, y, x = centroid
# check if the centroid is within the bbox
if (
z >= z_min
and z <= z_max
and y >= y_min
and y <= y_max
and x >= x_min
and x <= x_max
):
return True
else:
return False
[docs]
def check_if_centroid_within_mask(
centroid: tuple, mask: np.ndarray, label: int
) -> bool:
"""
Check if the centroid is within the mask
Parameters
----------
centroid : tuple
Centroid of the object in the order of (z, y, x)
Order of the centroid is important
mask : np.ndarray
The mask to check against
Returns
-------
bool
True if the centroid is within the mask, False otherwise
"""
z, y, x = centroid
z = np.round(z).astype(int)
y = np.round(y).astype(int)
x = np.round(x).astype(int)
# check if the centroid is within the segmentation mask
cell_label = mask[z, y, x]
if cell_label > 0 and cell_label == label:
return True
else:
return False
[docs]
def get_labels_for_post_hoc_reassignment(
compartment_mask: np.ndarray,
compartment_name: str,
) -> pd.DataFrame:
"""Collect centroid and bbox data for mask relabeling.
Parameters
----------
compartment_mask : np.ndarray
Labeled mask for a compartment.
compartment_name : str
Name of the compartment.
Returns
-------
pd.DataFrame
DataFrame of centroids, bboxes, and labels.
"""
# get the centroid and bbox of the cell mask
compartment_df = pd.DataFrame.from_dict(
skimage.measure.regionprops_table(
compartment_mask,
properties=["centroid", "bbox"],
)
)
compartment_df["compartment"] = compartment_name
compartment_df["label"] = compartment_mask[
compartment_df["centroid-0"].astype(int),
compartment_df["centroid-1"].astype(int),
compartment_df["centroid-2"].astype(int),
]
# remove all 0 labels
compartment_df = compartment_df[compartment_df["label"] > 0].reset_index(drop=True)
return compartment_df
[docs]
def mask_label_reassignment(
mask_df: pd.DataFrame,
mask_input: np.ndarray,
) -> np.ndarray:
"""
Reassign the labels of the mask based on the mask_df
Parameters
----------
mask_df : pd.DataFrame
DataFrame containing the labels and centroids of the mask
mask_input : np.ndarray
The input mask to reassign the labels to
Returns
-------
np.ndarray
The mask with reassigned labels
"""
for i, row in mask_df.iterrows():
if row["label"] == row["new_label"]:
# if the label is already the new label, skip
continue
mask_input[mask_input == row["label"]] = row["new_label"]
return mask_input
[docs]
def run_post_hoc_mask_reassignment(
nuclei_mask: np.ndarray,
cell_mask: np.ndarray,
nuclei_df: pd.DataFrame,
cell_df: pd.DataFrame,
return_dataframe: bool = False,
) -> tuple[np.ndarray, pd.DataFrame | None]:
"""Reassign nuclei labels based on cell containment.
Parameters
----------
nuclei_mask : np.ndarray
Nuclei segmentation mask.
cell_mask : np.ndarray
Cell segmentation mask.
nuclei_df : pd.DataFrame
DataFrame with nuclei centroids and labels.
cell_df : pd.DataFrame
DataFrame with cell centroids and labels.
return_dataframe : bool, optional
Whether to return the merged DataFrame, by default False.
Returns
-------
tuple[np.ndarray, pd.DataFrame | None]
Updated nuclei mask and optional merged DataFrame.
"""
# if a centroid of the nuclei is inside the cell mask,
# then make the cell retain the label of the nuclei
nuclei_df["new_label"] = nuclei_df["label"].copy()
for i, row in nuclei_df.iterrows():
for j, row2 in cell_df.iterrows():
nuc_contained_in_cell_bool = check_if_centroid_within_mask(
centroid=(
row["centroid-0"],
row["centroid-1"],
row["centroid-2"],
),
mask=cell_mask,
label=row2["label"],
)
if nuc_contained_in_cell_bool:
# if the centroid of the nuclei is within the cell mask,
# then make the cell retain the label of the nuclei
nuclei_df.at[i, "new_label"] = row2["label"]
break
else:
pass
# merge the dataframes
nuclei_and_cell_df = pd.merge(
nuclei_df,
cell_df,
left_on="new_label",
right_on="label",
suffixes=("_nuclei", "_cell"),
)
# remove the edge cases
cell_mask = remove_edge_cases(
mask=cell_mask,
border=10,
)
nuclei_mask = remove_edge_cases(
mask=nuclei_mask,
border=10,
)
# reassign the labels of the cell mask
nuclei_mask = mask_label_reassignment(
mask_df=nuclei_df,
mask_input=nuclei_mask,
)
if return_dataframe:
return nuclei_mask, nuclei_and_cell_df
else:
return nuclei_mask, None
# ----------------------------------------------------------------------
# cytoplasm mask creation
# ----------------------------------------------------------------------
[docs]
def create_cytoplasm_masks(
nuclei_masks: np.ndarray,
cell_masks: np.ndarray,
) -> np.ndarray:
"""Create cytoplasm masks by subtracting nuclei from cells.
Parameters
----------
nuclei_masks : np.ndarray
Nuclei segmentation masks.
cell_masks : np.ndarray
Cell segmentation masks.
Returns
-------
np.ndarray
Cytoplasm masks.
"""
cytoplasm_masks = np.zeros_like(cell_masks)
# filter masks that are not the background
for z_slice_index in range(nuclei_masks.shape[0]):
nuclei_slice_mask = nuclei_masks[z_slice_index]
cell_slice_mask = cell_masks[z_slice_index]
cytoplasm_mask = cell_slice_mask.copy()
cytoplasm_mask[nuclei_slice_mask > 0] = 0 # subtraction happens here
cytoplasm_masks[z_slice_index] = cytoplasm_mask
return cytoplasm_masks
[docs]
def clean_border_objects(
segmentation: np.ndarray, border_width: int = 20
) -> np.ndarray:
"""Remove objects touching the segmentation border.
Parameters
----------
segmentation : np.ndarray
Labeled segmentation mask.
border_width : int, optional
Width of the border region, by default 20.
Returns
-------
np.ndarray
Cleaned segmentation mask.
"""
cleaned_seg = segmentation.copy()
max_z, max_y, max_x = segmentation.shape
border_labels = set()
# check x borders
border_labels.update(np.unique(segmentation[:, :, max_x - border_width :]))
border_labels.update(np.unique(segmentation[:, :, :border_width]))
# check y borders
border_labels.update(np.unique(segmentation[:, :border_width, :]))
border_labels.update(np.unique(segmentation[:, -border_width:, :]))
# remove these labels
for label in border_labels:
if label == 0:
continue
cleaned_seg[segmentation == label] = 0
return cleaned_seg
# if there are any singletons remove the labels in all masks
[docs]
def remove_label_id(mask_image: np.ndarray, label_id_to_remove: int) -> np.ndarray:
"""
Remove the label id
Parameters
----------
mask_image : np.ndarray
Mask image from which to remove the label id
label_id_to_remove : int
Label id to remove from the mask image
Returns
-------
np.ndarray
Mask image with the label id removed
"""
mask_image[mask_image == label_id_to_remove] = 0
return mask_image