Source code for image_analysis_3D.featurization_utils.sammed3d_featurizer
"""
SAM-Med3D Feature Extractor
Convert SAM-Med3D from segmentation to featurization model.
SAM-Med3D Architecture:
- 3D Image Encoder (ViT-based): Extracts features from 3D volumes
- 3D Prompt Encoder: Processes prompts, which are supervision signals provided by user for segmentation at inference time (not needed nor used for featurization)
- 3D Mask Decoder: Generates segmentation masks (not needed for featurization)
For featurization, we extract embeddings from the 3D image encoder.
Requirements:
pip install torch torchvision monai einops timm
# For using pretrained SAM-Med3D:
pip install medim
"""
from __future__ import annotations

import contextlib
import logging
import os
import sys
from io import StringIO
from typing import Dict, List, Optional, Union

import medim
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from image_analysis_3D.featurization_utils.loading_classes import ObjectLoader
[docs]
class SAMMed3DFeatureExtractor:
    """
    Extract features from 3D microscope volumes using the SAM-Med3D encoder.

    This class wraps the SAM-Med3D model, keeps only its 3D image encoder and
    reduces the encoder output to global, patch-level or "cls" feature arrays
    for downstream tasks like classification, clustering, or retrieval.
    """

    def __init__(
        self,
        model_path: Optional[str] = None,
        device: Optional[str] = "cuda" if torch.cuda.is_available() else "cpu",
        use_medim: Optional[bool] = True,
        image_size: Optional[int] = 128,
    ):
        """
        Initialize the SAM-Med3D feature extractor.

        Parameters
        ----------
        model_path : str or None, optional
            Path or URL of a SAM-Med3D checkpoint (.pth file). When None and
            MedIM is used, the SAM-Med3D-turbo checkpoint hosted on Hugging
            Face is loaded.
        device : str, optional
            Device to run inference on ("cuda" when available, else "cpu").
        use_medim : bool, optional
            Whether to load the model through the MedIM package.
        image_size : int, optional
            Edge length of the cubic input given to the encoder; volumes are
            resized to ``(image_size, image_size, image_size)``
            (SAM-Med3D typically uses 128).

        Raises
        ------
        RuntimeError
            If no image encoder could be loaded.
        """
        self.device = device
        self.image_size = image_size
        model, self.encoder = self._load_model(model_path, use_medim)
        # Only the encoder branch is used; drop the local reference to the
        # full model.
        del model
        if self.encoder is None:
            # _load_model returns (None, None) when every loading strategy
            # failed; fail loudly here instead of on `None.to(...)`.
            raise RuntimeError(
                "Failed to load a SAM-Med3D image encoder. Install MedIM "
                "(pip install medim) or check the checkpoint path."
            )
        self.encoder.to(device)
        self.encoder.eval()
        # Probe the encoder once to record the channel size of its output.
        self.feature_dim = self._get_feature_dim()

    def _load_model(self, model_path: Optional[str], use_medim: bool) -> tuple:
        """
        Load SAM-Med3D and return ``(model, image_encoder)``.

        MedIM is tried first when ``use_medim`` is True; if that fails, the
        manual repository-based path is attempted. ``(None, None)`` is
        returned when no encoder could be loaded.
        """
        # Quiet noisy third-party loggers (this changes their level
        # process-wide).
        logging.getLogger("transformers").setLevel(logging.ERROR)
        logging.getLogger("torch").setLevel(logging.ERROR)

        if use_medim:
            if model_path is None:
                # Use pretrained SAM-Med3D-turbo
                model_path = "https://huggingface.co/blueyo0/SAM-Med3D/resolve/main/sam_med3d_turbo.pth"
            try:
                # Silence MedIM's console output for this call only; the
                # context manager restores sys.stdout even if loading raises.
                with contextlib.redirect_stdout(StringIO()):
                    model = medim.create_model(
                        "SAM-Med3D", pretrained=True, checkpoint_path=model_path
                    )
                return model, model.image_encoder
            except ImportError:
                print("⚠ MedIM not installed.")
                print(" Install with: pip install medim")
                print(" Falling back to manual loading...")
            except Exception as e:
                print(f"⚠ Failed to load via MedIM: {e}")
                print(" Falling back to manual loading...")
        return self._load_model_from_repo(model_path)

    def _load_model_from_repo(self, model_path: Optional[str]) -> tuple:
        """
        Try to build SAM-Med3D from a local clone of its repository.

        Returns ``(model, image_encoder)``, or ``(None, None)`` on failure.
        Note that ``_build_sammed3d_model`` currently always raises
        NotImplementedError, so this path returns ``(None, None)``.
        """
        possible_paths = [
            "./SAM-Med3D",
            "../SAM-Med3D",
            "../../SAM-Med3D",
            os.path.expanduser("~/SAM-Med3D"),
        ]
        try:
            sammed3d_path = next(
                (
                    path
                    for path in possible_paths
                    if os.path.exists(os.path.join(path, "segment_anything"))
                ),
                None,
            )
            if sammed3d_path:
                if sammed3d_path not in sys.path:
                    sys.path.insert(0, sammed3d_path)
                print(f"✓ Found SAM-Med3D at {sammed3d_path}")
            print("✓ Loading SAM-Med3D manually from repo")
            model = self._build_sammed3d_model()
            if model_path and os.path.exists(model_path):
                # NOTE(review): torch.load unpickles the file; only load
                # trusted checkpoints.
                checkpoint = torch.load(model_path, map_location="cpu")
                state_dict = checkpoint["model"] if "model" in checkpoint else checkpoint
                model.load_state_dict(state_dict, strict=False)
                print(f"✓ Loaded weights from {model_path}")
            else:
                print("⚠ No pretrained weights loaded (training from scratch)")
            return model, model.image_encoder
        except ImportError as e:
            print(f"⚠ SAM-Med3D repo not found: {e}")
            print(" To use full SAM-Med3D:")
            print(" 1. git clone https://github.com/uni-medical/SAM-Med3D")
            print(" 2. pip install -r SAM-Med3D/requirements.txt")
            print(" OR install MedIM: pip install medim")
        except Exception as e:
            print(f"⚠ Error loading SAM-Med3D: {e}")
        return None, None

    def _build_sammed3d_model(self) -> nn.Module:
        """Build the SAM-Med3D model architecture (not implemented here)."""
        # This would require the actual SAM-Med3D code.
        raise NotImplementedError(
            "Manual model building requires SAM-Med3D repository. "
            "Please install MedIM: pip install medim"
        )

    @staticmethod
    def _channel_dim(features: torch.Tensor) -> int:
        """Return the channel size: axis 1 of a (B, C, Z, Y, X) map, else the last axis."""
        return features.shape[1] if features.dim() == 5 else features.shape[-1]

    def _get_feature_dim(self) -> int | dict:
        """Return the channel dimension of the raw encoder output."""
        with torch.no_grad():
            dummy_input = torch.randn(
                1, 1, self.image_size, self.image_size, self.image_size
            ).to(self.device)
            features = self._extract_features(dummy_input)
        if isinstance(features, dict):
            return {k: self._channel_dim(v) for k, v in features.items()}
        return self._channel_dim(features)

    def _extract_features(
        self, x: torch.Tensor, feature_type: str | None = None
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Run the encoder on ``x`` and reduce its output by ``feature_type``.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor (B, C, Z, Y, X).
        feature_type : str or None, optional
            - "global": mean over spatial positions (5D output) or tokens
              (3D output), giving (B, C).
            - "patch": a 5D output is flattened to (B, Z*Y*X, C); other
              ranks are returned unchanged.
            - "cls": token 0 of a (B, N, C) output, or position [0, 0, 0] of
              a (B, C, Z, Y, X) output.
            - None or any other value: the raw encoder output.

        Returns
        -------
        torch.Tensor or dict
            Features based on ``feature_type``.

        Raises
        ------
        ValueError
            If "cls" is requested and the output is neither 3D nor 5D.
        """
        if hasattr(self.encoder, "forward_features"):
            features = self.encoder.forward_features(x)
        else:
            features = self.encoder(x)

        if feature_type == "global":
            if features.dim() == 5:  # (B, C, Z, Y, X)
                features = F.adaptive_avg_pool3d(features, 1).flatten(1)
            elif features.dim() == 3:  # (B, N, C) - transformer output
                features = features.mean(dim=1)
        elif feature_type == "patch":
            if features.dim() == 5:
                batch, channels = features.shape[:2]
                features = features.reshape(batch, channels, -1).permute(0, 2, 1)
        elif feature_type == "cls":
            if features.dim() == 3:  # (B, N, C)
                features = features[:, 0, :]  # First token is usually CLS
            elif features.dim() == 5:
                # A (B, C, Z, Y, X) map has no dedicated CLS token; the
                # embedding at spatial position [0, 0, 0] is used instead.
                features = features[:, :, 0, 0, 0]  # (B, C)
            else:
                raise ValueError("CLS token extraction requires transformer output.")
        return features

    def extract(
        self,
        volume: Union[np.ndarray, torch.Tensor],
        normalize: bool = True,
        feature_type: str | None = None,
    ) -> np.ndarray:
        """
        Extract features from a 3D volume.

        Parameters
        ----------
        volume : numpy.ndarray or torch.Tensor
            3D volume (Z, Y, X) or (C, Z, Y, X) or (B, C, Z, Y, X). Any
            numeric dtype is accepted and converted to float32.
        normalize : bool, optional
            Whether to min-max scale the volume to [0, 1] (computed over the
            whole input, including every batch entry).
        feature_type : str or None, optional
            Reduction applied to the encoder output ("global", "patch",
            "cls"); None returns the raw encoder output.

        Returns
        -------
        numpy.ndarray or dict
            Feature array(s) with a leading batch dimension.

        Raises
        ------
        ValueError
            If ``volume`` is not 3D, 4D or 5D.
        """
        if isinstance(volume, torch.Tensor):
            # The encoder weights are float32; other dtypes would fail in
            # interpolation or in the forward pass.
            volume = volume.float()
        else:
            # ascontiguousarray copes with dtypes torch.from_numpy rejects
            # (e.g. uint16 on older torch) and with negative strides
            # (flipped views).
            volume = torch.from_numpy(np.ascontiguousarray(volume, dtype=np.float32))

        if volume.dim() == 3:  # (Z, Y, X)
            volume = volume.unsqueeze(0).unsqueeze(0)  # (1, 1, Z, Y, X)
        elif volume.dim() == 4:  # (C, Z, Y, X)
            volume = volume.unsqueeze(0)  # (1, C, Z, Y, X)
        elif volume.dim() != 5:
            raise ValueError(
                f"Expected a 3D, 4D or 5D volume, got shape {tuple(volume.shape)}"
            )

        if normalize:
            volume = (volume - volume.min()) / (volume.max() - volume.min() + 1e-8)

        target_size = (self.image_size, self.image_size, self.image_size)
        if tuple(volume.shape[-3:]) != target_size:
            volume = F.interpolate(
                volume, size=target_size, mode="trilinear", align_corners=False
            )

        volume = volume.to(self.device)
        with torch.no_grad():
            features = self._extract_features(volume, feature_type=feature_type)

        if isinstance(features, dict):
            return {k: v.cpu().numpy() for k, v in features.items()}
        return features.cpu().numpy()

    def extract_batch(
        self,
        volumes: List[Union[np.ndarray, torch.Tensor]],
        batch_size: int = 4,
        normalize: bool = True,
        feature_type: str | None = None,
    ) -> np.ndarray:
        """
        Extract features from multiple volumes.

        Parameters
        ----------
        volumes : list
            List of 3D volumes.
        batch_size : int, optional
            Chunk size used to iterate over ``volumes``. Each volume is still
            encoded on its own, so this does not change the result.
        normalize : bool, optional
            Forwarded to :meth:`extract`.
        feature_type : str or None, optional
            Forwarded to :meth:`extract`.

        Returns
        -------
        numpy.ndarray
            The per-volume outputs of :meth:`extract` stacked along a new
            first axis, e.g. (N, 1, D) for "global" features of 3D volumes.

        Raises
        ------
        ValueError
            If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        all_features = []
        for start in range(0, len(volumes), batch_size):
            for vol in volumes[start : start + batch_size]:
                all_features.append(
                    self.extract(vol, normalize=normalize, feature_type=feature_type)
                )
        return np.array(all_features)
[docs]
class TransformerBlock3D(nn.Module):
    """Pre-norm transformer block: self-attention then an MLP, each residual."""

    def __init__(self, dim: int, num_heads: int, mlp_ratio: float = 4.0):
        """Initialize a 3D transformer block.

        Parameters
        ----------
        dim : int
            Embedding dimension.
        num_heads : int
            Number of attention heads.
        mlp_ratio : float, optional
            Hidden size multiplier for the MLP, by default 4.0.
        """
        super().__init__()
        # Submodule registration order is kept stable so state_dict keys match.
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim), nn.GELU(), nn.Linear(mlp_hidden_dim, dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention and MLP blocks.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (B, N, C).

        Returns
        -------
        torch.Tensor
            Output tensor of shape (B, N, C).
        """
        # Normalize once and share the result as query, key and value.
        normed = self.norm1(x)
        attn_out, _ = self.attn(normed, normed, normed)
        x = x + attn_out
        return x + self.mlp(self.norm2(x))
# Complete pipeline: microscope volume -> SAM-Med3D features
[docs]
class MicroscopySAMMed3DPipeline:
    """End-to-end pipeline for microscopy feature extraction.

    Optionally min-max normalizes and Gaussian-smooths each volume before
    handing it to a ``SAMMed3DFeatureExtractor``.
    """

    def __init__(
        self,
        sammed3d_path: Optional[str] = None,
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
    ):
        """Initialize the pipeline with a SAM-Med3D extractor.

        Parameters
        ----------
        sammed3d_path : Optional[str], optional
            Path to a SAM-Med3D checkpoint, by default None.
        device : str, optional
            Device string for torch execution, by default auto-detected.
        """
        self.extractor = SAMMed3DFeatureExtractor(
            model_path=sammed3d_path, device=device
        )

    def preprocess_volume(self, volume: np.ndarray, sigma: float = 0.5) -> np.ndarray:
        """Min-max normalize a volume to [0, 1] and apply light Gaussian smoothing.

        Parameters
        ----------
        volume : numpy.ndarray
            Input volume of any numeric dtype.
        sigma : float, optional
            Standard deviation of the Gaussian filter, by default 0.5.

        Returns
        -------
        numpy.ndarray
            Smoothed, normalized volume (floating point).
        """
        from scipy import ndimage

        volume = np.asarray(volume)
        if not np.issubdtype(volume.dtype, np.floating):
            # Integer arithmetic can overflow in the subtractions below
            # (e.g. int8 max - min), so work in float.
            volume = volume.astype(np.float64)
        volume = (volume - volume.min()) / (volume.max() - volume.min() + 1e-8)
        return ndimage.gaussian_filter(volume, sigma=sigma)

    def extract_features(
        self,
        volume: np.ndarray,
        preprocess: bool = True,
        feature_type: str | None = None,
    ) -> np.ndarray:
        """
        Extract features from a microscopy volume.

        Parameters
        ----------
        volume : numpy.ndarray
            3D numpy array (Z, Y, X).
        preprocess : bool, optional
            Whether to run :meth:`preprocess_volume` first.
        feature_type : str or None, optional
            Forwarded to the extractor ("global", "patch", "cls" or None).

        Returns
        -------
        numpy.ndarray
            Feature array as returned by the extractor.
        """
        if preprocess:
            volume = self.preprocess_volume(volume)
        return self.extractor.extract(volume, feature_type=feature_type)

    def extract_features_batch(
        self,
        volumes: List[np.ndarray],
        preprocess: bool = True,
        batch_size: int = 4,
        feature_type: str | None = None,
    ) -> np.ndarray:
        """Extract features from multiple volumes.

        Parameters
        ----------
        volumes : list of numpy.ndarray
            Volumes to featurize.
        preprocess : bool, optional
            Whether to run :meth:`preprocess_volume` on each volume first.
        batch_size : int, optional
            Kept for API compatibility; volumes are encoded one at a time,
            so it does not affect the result. Must be >= 1.
        feature_type : str or None, optional
            Forwarded to the extractor for every volume.

        Returns
        -------
        numpy.ndarray
            Per-volume extractor outputs stacked along a new first axis.

        Raises
        ------
        ValueError
            If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if preprocess:
            volumes = [self.preprocess_volume(v) for v in volumes]
        # Encode each volume via `extract` so that `feature_type` is honored;
        # this matches the per-volume loop of the extractor's batch helper.
        features = [
            self.extractor.extract(v, feature_type=feature_type) for v in volumes
        ]
        return np.array(features)
[docs]
def check_for_zero_objects(label_image: np.ndarray) -> bool:
    """Return True when ``label_image`` contains no non-background (non-zero) label.

    Label 0 is treated as background, so an all-zero (or empty) image has
    no objects.
    """
    # A single pass over the array replaces the sort that np.unique needs.
    return not np.any(label_image != 0)
[docs]
def call_SAMMed3D_pipeline(
    object_loader: ObjectLoader,
    SAMMed3D_model_path: Optional[str] = None,
    feature_type: str | List | tuple = ("global", "patch", "cls"),
    extractor: Optional["MicroscopySAMMed3DPipeline"] = None,
) -> dict:
    """
    Call the SAMMed3D pipeline to extract features per patient, well-fov.

    Here we call the SAMMed3D pipeline to extract features for each object
    in the label image. Every object is featurized on its own copy of the
    image in which all voxels outside that object are set to 0.

    Parameters
    ----------
    object_loader : ObjectLoader
        Class that loads the image and label image for a given patient,
        well-fov, channel, compartment.
    SAMMed3D_model_path : Optional[str], optional
        Path to the SAMMed3D model, by default None. Ignored if extractor is provided.
    feature_type : str | List | tuple, optional
        Feature type(s) to extract, by default ("global", "patch", "cls").
    extractor : Optional[MicroscopySAMMed3DPipeline], optional
        Pre-loaded extractor instance. If provided, SAMMed3D_model_path is ignored.
        Use this to avoid reloading the model in loops. By default None.

    Returns
    -------
    dict
        Long-format dictionary of extracted features with keys:
        - "object_id": List of object IDs
        - "feature_name": List of feature names ("Feature-<type><index>")
        - "channel": List of channels
        - "compartment": List of compartments
        - "value": List of feature values
        - "feature_type": List of feature types (always "SAMMed3D")
    """
    assert isinstance(feature_type, (str, list, tuple)), (
        "feature_type must be a string or list of strings"
    )
    # Treat a single string and a sequence of strings uniformly.
    feature_types = [feature_type] if isinstance(feature_type, str) else list(feature_type)

    image_object = object_loader.image
    label_object = object_loader.label_image
    labels = object_loader.object_ids
    output_dict = {
        "object_id": [],
        "feature_name": [],
        "channel": [],
        "compartment": [],
        "value": [],
        "feature_type": [],
    }
    if check_for_zero_objects(label_object):
        return output_dict

    # Use provided extractor or create new one
    if extractor is None:
        extractor = MicroscopySAMMed3DPipeline(
            sammed3d_path=SAMMed3D_model_path,
            device="cuda" if torch.cuda.is_available() else "cpu",
        )

    channel = object_loader.channel
    compartment = object_loader.compartment
    for label in labels:
        # Voxels of this object; background (<= 0) never counts as part of it.
        object_mask = (label_object == label) & (label_object > 0)
        selected_image_object = image_object.copy()
        selected_image_object[~object_mask] = 0
        for ft in feature_types:
            features = extractor.extract_features(
                selected_image_object, feature_type=ft
            )  # preprocess the volume
            for i, feature_value in enumerate(features.flatten()):
                output_dict["object_id"].append(label)
                output_dict["feature_name"].append(f"Feature-{ft}{i}")
                output_dict["channel"].append(channel)
                output_dict["compartment"].append(compartment)
                output_dict["value"].append(feature_value)
                output_dict["feature_type"].append("SAMMed3D")
    return output_dict
[docs]
def call_whole_image_sammed3d_pipeline(
    image: np.ndarray,
    SAMMed3D_model_path: Optional[str] = None,
    feature_type: str | List | tuple = ("global", "patch", "cls"),
    extractor: Optional["MicroscopySAMMed3DPipeline"] = None,
) -> dict:
    """
    Call the SAMMed3D pipeline to extract features for the whole image.

    This function is called per patient, well-fov and extracts features
    for the whole FOV volume using the SAMMed3D pipeline.

    Parameters
    ----------
    image : np.ndarray
        3D numpy array of the image.
    SAMMed3D_model_path : Optional[str], optional
        Path to the SAMMed3D model, by default None. Ignored if extractor is provided.
    feature_type : str | List | tuple, optional
        Feature type(s) to extract, by default ("global", "patch", "cls").
    extractor : Optional[MicroscopySAMMed3DPipeline], optional
        Pre-loaded extractor instance. If provided, SAMMed3D_model_path is ignored.
        Use this to avoid reloading the model in loops. By default None.

    Returns
    -------
    dict
        Long-format dictionary of whole-image features with keys:
        - "feature_name": List of feature names ("<type>-feature-<index>")
        - "value": List of feature values
        - "feature_type": List of feature types (always "SAMMed3D")
        - "compartment": List of compartments (always "Image")
    """
    assert isinstance(feature_type, (str, list, tuple)), (
        "feature_type must be a string or list of strings"
    )
    # Treat a single string and a sequence of strings uniformly.
    feature_types = [feature_type] if isinstance(feature_type, str) else list(feature_type)

    output_dict = {
        "feature_name": [],
        "value": [],
        "feature_type": [],
        "compartment": [],
    }

    # Use provided extractor or create new one
    if extractor is None:
        extractor = MicroscopySAMMed3DPipeline(
            sammed3d_path=SAMMed3D_model_path,
            device="cuda" if torch.cuda.is_available() else "cpu",
        )

    for ft in feature_types:
        features = extractor.extract_features(image, feature_type=ft)
        for i, feature_value in enumerate(features.flatten()):
            output_dict["feature_name"].append(f"{ft}-feature-{i}")
            output_dict["value"].append(feature_value)
            output_dict["compartment"].append("Image")
            output_dict["feature_type"].append("SAMMed3D")
    return output_dict