-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Add affine-aware landmark heatmap generation #8957
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -46,7 +46,7 @@ | |||||||||||||||||||||||||||||||||
| VoteEnsemble, | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| from monai.transforms.transform import MapTransform | ||||||||||||||||||||||||||||||||||
| from monai.transforms.utility.array import ToTensor | ||||||||||||||||||||||||||||||||||
| from monai.transforms.utility.array import ApplyTransformToPoints, ToTensor | ||||||||||||||||||||||||||||||||||
| from monai.transforms.utils import allow_missing_keys_mode, convert_applied_interp_mode | ||||||||||||||||||||||||||||||||||
| from monai.utils import PostFix, convert_to_tensor, ensure_tuple, ensure_tuple_rep | ||||||||||||||||||||||||||||||||||
| from monai.utils.type_conversion import convert_to_dst_type | ||||||||||||||||||||||||||||||||||
|
|
@@ -527,6 +527,12 @@ class GenerateHeatmapd(MapTransform): | |||||||||||||||||||||||||||||||||
| heatmap_keys: keys to store output heatmaps. Default: "{key}_heatmap" for each key. | ||||||||||||||||||||||||||||||||||
| ref_image_keys: keys of reference images to inherit spatial metadata from. When provided, heatmaps will | ||||||||||||||||||||||||||||||||||
| have the same shape, affine, and spatial metadata as the reference images. | ||||||||||||||||||||||||||||||||||
| coordinate_space: coordinate system of the input points. ``"voxel"`` keeps the existing behavior and treats | ||||||||||||||||||||||||||||||||||
| points as voxel coordinates in the output heatmap space. ``"world"`` transforms points to reference-image | ||||||||||||||||||||||||||||||||||
| voxel coordinates with ``ref_image_keys`` before generating heatmaps. If the points are a ``MetaTensor`` | ||||||||||||||||||||||||||||||||||
| with their own affine, that affine is used as the point-to-world transform. | ||||||||||||||||||||||||||||||||||
| visibility_keys: optional keys to store a boolean visibility mask for each point after coordinate conversion. | ||||||||||||||||||||||||||||||||||
| The value is ``True`` when the transformed point is finite and inside the heatmap spatial shape. | ||||||||||||||||||||||||||||||||||
| spatial_shape: spatial dimensions of output heatmaps. Can be: | ||||||||||||||||||||||||||||||||||
| - Single shape (tuple): applied to all keys | ||||||||||||||||||||||||||||||||||
| - List of shapes: one per key (must match keys length) | ||||||||||||||||||||||||||||||||||
|
|
@@ -542,6 +548,7 @@ class GenerateHeatmapd(MapTransform): | |||||||||||||||||||||||||||||||||
| ValueError: If heatmap_keys/ref_image_keys length doesn't match keys length. | ||||||||||||||||||||||||||||||||||
| ValueError: If no spatial shape can be determined (need spatial_shape or ref_image_keys). | ||||||||||||||||||||||||||||||||||
| ValueError: If input points have invalid shape (must be 2D array with shape (N, D)). | ||||||||||||||||||||||||||||||||||
| ValueError: If ``coordinate_space="world"`` is used without a reference affine. | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| Example: | ||||||||||||||||||||||||||||||||||
| .. code-block:: python | ||||||||||||||||||||||||||||||||||
|
|
@@ -573,12 +580,24 @@ class GenerateHeatmapd(MapTransform): | |||||||||||||||||||||||||||||||||
| result = transform(data) | ||||||||||||||||||||||||||||||||||
| # result["landmarks_heatmap"] has shape (2, 64, 64) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| # World-space landmarks can be converted against the reference affine. | ||||||||||||||||||||||||||||||||||
| transform = GenerateHeatmapd( | ||||||||||||||||||||||||||||||||||
| keys="landmarks_world", | ||||||||||||||||||||||||||||||||||
| heatmap_keys="landmark_heatmap", | ||||||||||||||||||||||||||||||||||
| ref_image_keys="image", | ||||||||||||||||||||||||||||||||||
| coordinate_space="world", | ||||||||||||||||||||||||||||||||||
| visibility_keys="landmark_visible", | ||||||||||||||||||||||||||||||||||
| sigma=2.0, | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| Notes: | ||||||||||||||||||||||||||||||||||
| - Default heatmap_keys are generated as "{key}_heatmap" for each input key | ||||||||||||||||||||||||||||||||||
| - Shape inference precedence: static spatial_shape > ref_image | ||||||||||||||||||||||||||||||||||
| - Input points shape: (N, D) where N is number of landmarks, D is spatial dimensions | ||||||||||||||||||||||||||||||||||
| - Output heatmap shape: (N, H, W) for 2D or (N, H, W, D) for 3D | ||||||||||||||||||||||||||||||||||
| - When using ref_image_keys, heatmaps inherit affine and spatial metadata from reference | ||||||||||||||||||||||||||||||||||
| - ``coordinate_space="world"`` assumes that the points and reference affine use the same world-coordinate | ||||||||||||||||||||||||||||||||||
| convention. Convert LPS/RAS conventions before calling this transform if needed. | ||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| backend = GenerateHeatmap.backend | ||||||||||||||||||||||||||||||||||
|
|
@@ -590,13 +609,20 @@ class GenerateHeatmapd(MapTransform): | |||||||||||||||||||||||||||||||||
| _ERR_NO_SHAPE = "Unable to determine spatial shape for GenerateHeatmapd. Provide spatial_shape or ref_image_keys." | ||||||||||||||||||||||||||||||||||
| _ERR_INVALID_POINTS = "Landmark arrays must be 2D with shape (N, D)." | ||||||||||||||||||||||||||||||||||
| _ERR_REF_NO_SHAPE = "Reference data must define a shape attribute." | ||||||||||||||||||||||||||||||||||
| _ERR_VISIBILITY_KEYS_LEN = "Argument `visibility_keys` length must match keys length when provided." | ||||||||||||||||||||||||||||||||||
| _ERR_COORDINATE_SPACE_LEN = ( | ||||||||||||||||||||||||||||||||||
| "Argument `coordinate_space` length must match keys length when providing per-key values." | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| _SUPPORTED_COORDINATE_SPACES = {"voxel", "world"} | ||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win Make the coordinate-space set immutable. Ruff flags this mutable class attribute. Use Proposed fix- _SUPPORTED_COORDINATE_SPACES = {"voxel", "world"}
+ _SUPPORTED_COORDINATE_SPACES = frozenset({"voxel", "world"})📝 Committable suggestion
Suggested change
🧰 Tools🪛 Ruff (0.15.18)[warning] 616-616: Mutable default value for class attribute (RUF012) 🤖 Prompt for AI AgentsSource: Linters/SAST tools |
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def __init__( | ||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||
| keys: KeysCollection, | ||||||||||||||||||||||||||||||||||
| sigma: Sequence[float] | float = 5.0, | ||||||||||||||||||||||||||||||||||
| heatmap_keys: KeysCollection | None = None, | ||||||||||||||||||||||||||||||||||
| ref_image_keys: KeysCollection | None = None, | ||||||||||||||||||||||||||||||||||
| coordinate_space: str | Sequence[str] = "voxel", | ||||||||||||||||||||||||||||||||||
| visibility_keys: KeysCollection | None = None, | ||||||||||||||||||||||||||||||||||
| spatial_shape: Sequence[int] | Sequence[Sequence[int]] | None = None, | ||||||||||||||||||||||||||||||||||
| truncated: float = 4.0, | ||||||||||||||||||||||||||||||||||
| normalize: bool = True, | ||||||||||||||||||||||||||||||||||
|
|
@@ -606,22 +632,27 @@ def __init__( | |||||||||||||||||||||||||||||||||
| super().__init__(keys, allow_missing_keys) | ||||||||||||||||||||||||||||||||||
| self.heatmap_keys = self._prepare_heatmap_keys(heatmap_keys) | ||||||||||||||||||||||||||||||||||
| self.ref_image_keys = self._prepare_optional_keys(ref_image_keys) | ||||||||||||||||||||||||||||||||||
| self.coordinate_spaces = self._prepare_coordinate_spaces(coordinate_space) | ||||||||||||||||||||||||||||||||||
| self.visibility_keys = self._prepare_visibility_keys(visibility_keys) | ||||||||||||||||||||||||||||||||||
| self.static_shapes = self._prepare_shapes(spatial_shape) | ||||||||||||||||||||||||||||||||||
| self.generator = GenerateHeatmap( | ||||||||||||||||||||||||||||||||||
| sigma=sigma, spatial_shape=None, truncated=truncated, normalize=normalize, dtype=dtype | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| self.world_to_voxel = ApplyTransformToPoints(dtype=torch.float32, invert_affine=True) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]: | ||||||||||||||||||||||||||||||||||
| d = dict(data) | ||||||||||||||||||||||||||||||||||
| for key, out_key, ref_key, static_shape in self.key_iterator( | ||||||||||||||||||||||||||||||||||
| d, self.heatmap_keys, self.ref_image_keys, self.static_shapes | ||||||||||||||||||||||||||||||||||
| for key, out_key, ref_key, coordinate_space, visibility_key, static_shape in self.key_iterator( | ||||||||||||||||||||||||||||||||||
| d, self.heatmap_keys, self.ref_image_keys, self.coordinate_spaces, self.visibility_keys, self.static_shapes | ||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||
| points = d[key] | ||||||||||||||||||||||||||||||||||
| shape = self._determine_shape(points, static_shape, d, ref_key) | ||||||||||||||||||||||||||||||||||
| reference = d.get(ref_key) if ref_key is not None and ref_key in d else None | ||||||||||||||||||||||||||||||||||
| points = self._convert_points(points, reference, coordinate_space) | ||||||||||||||||||||||||||||||||||
| visibility = self._compute_visibility(points, shape) | ||||||||||||||||||||||||||||||||||
| # The GenerateHeatmap transform will handle type conversion based on input points | ||||||||||||||||||||||||||||||||||
| heatmap = self.generator(points, spatial_shape=shape) | ||||||||||||||||||||||||||||||||||
| # If there's a reference image and we need to match its type/device | ||||||||||||||||||||||||||||||||||
| reference = d.get(ref_key) if ref_key is not None and ref_key in d else None | ||||||||||||||||||||||||||||||||||
| if reference is not None and isinstance(reference, (torch.Tensor, np.ndarray)): | ||||||||||||||||||||||||||||||||||
| # Convert to match reference type and device while preserving heatmap's dtype | ||||||||||||||||||||||||||||||||||
| heatmap, _, _ = convert_to_dst_type( | ||||||||||||||||||||||||||||||||||
|
|
@@ -632,6 +663,8 @@ def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]: | |||||||||||||||||||||||||||||||||
| heatmap.affine = reference.affine | ||||||||||||||||||||||||||||||||||
| self._update_spatial_metadata(heatmap, shape) | ||||||||||||||||||||||||||||||||||
| d[out_key] = heatmap | ||||||||||||||||||||||||||||||||||
| if visibility_key is not None: | ||||||||||||||||||||||||||||||||||
| d[visibility_key] = self._convert_visibility(visibility, d[key]) | ||||||||||||||||||||||||||||||||||
| return d | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def _prepare_heatmap_keys(self, heatmap_keys: KeysCollection | None) -> tuple[Hashable, ...]: | ||||||||||||||||||||||||||||||||||
|
|
@@ -654,6 +687,34 @@ def _prepare_optional_keys(self, maybe_keys: KeysCollection | None) -> tuple[Has | |||||||||||||||||||||||||||||||||
| raise ValueError(self._ERR_REF_KEYS_LEN) | ||||||||||||||||||||||||||||||||||
| return tuple(keys_tuple) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def _prepare_visibility_keys(self, maybe_keys: KeysCollection | None) -> tuple[Hashable | None, ...]: | ||||||||||||||||||||||||||||||||||
| if maybe_keys is None: | ||||||||||||||||||||||||||||||||||
| return (None,) * len(self.keys) | ||||||||||||||||||||||||||||||||||
| keys_tuple = ensure_tuple(maybe_keys) | ||||||||||||||||||||||||||||||||||
| if len(keys_tuple) == 1 and len(self.keys) > 1: | ||||||||||||||||||||||||||||||||||
| keys_tuple = keys_tuple * len(self.keys) | ||||||||||||||||||||||||||||||||||
| if len(keys_tuple) != len(self.keys): | ||||||||||||||||||||||||||||||||||
| raise ValueError(self._ERR_VISIBILITY_KEYS_LEN) | ||||||||||||||||||||||||||||||||||
| return tuple(keys_tuple) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def _prepare_coordinate_spaces(self, coordinate_space: str | Sequence[str]) -> tuple[str, ...]: | ||||||||||||||||||||||||||||||||||
| if isinstance(coordinate_space, str): | ||||||||||||||||||||||||||||||||||
| spaces = (coordinate_space,) * len(self.keys) | ||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||
| spaces = ensure_tuple(coordinate_space) | ||||||||||||||||||||||||||||||||||
| if len(spaces) == 1 and len(self.keys) > 1: | ||||||||||||||||||||||||||||||||||
| spaces = spaces * len(self.keys) | ||||||||||||||||||||||||||||||||||
| if len(spaces) != len(self.keys): | ||||||||||||||||||||||||||||||||||
| raise ValueError(self._ERR_COORDINATE_SPACE_LEN) | ||||||||||||||||||||||||||||||||||
| spaces = tuple(str(space).lower() for space in spaces) | ||||||||||||||||||||||||||||||||||
| invalid = set(spaces) - self._SUPPORTED_COORDINATE_SPACES | ||||||||||||||||||||||||||||||||||
| if invalid: | ||||||||||||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||||||||||||
| f"Unsupported coordinate_space value: {sorted(invalid)}. " | ||||||||||||||||||||||||||||||||||
| f"Supported values are {sorted(self._SUPPORTED_COORDINATE_SPACES)}." | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| return spaces | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def _prepare_shapes( | ||||||||||||||||||||||||||||||||||
| self, spatial_shape: Sequence[int] | Sequence[Sequence[int]] | None | ||||||||||||||||||||||||||||||||||
| ) -> tuple[tuple[int, ...] | None, ...]: | ||||||||||||||||||||||||||||||||||
|
|
@@ -711,6 +772,46 @@ def _update_spatial_metadata(self, heatmap: MetaTensor, spatial_shape: tuple[int | |||||||||||||||||||||||||||||||||
| """Set spatial_shape explicitly from resolved shape.""" | ||||||||||||||||||||||||||||||||||
| heatmap.meta["spatial_shape"] = tuple(int(v) for v in spatial_shape) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def _convert_points(self, points: Any, reference: Any, coordinate_space: str) -> Any: | ||||||||||||||||||||||||||||||||||
| if coordinate_space == "voxel": | ||||||||||||||||||||||||||||||||||
| return points | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| affine = self._get_reference_affine(reference) | ||||||||||||||||||||||||||||||||||
| points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False) | ||||||||||||||||||||||||||||||||||
| if points_t.ndim != 2: | ||||||||||||||||||||||||||||||||||
| raise ValueError(f"{self._ERR_INVALID_POINTS} Got {points_t.ndim}D tensor.") | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| if isinstance(points, MetaTensor): | ||||||||||||||||||||||||||||||||||
| points_to_transform = points.unsqueeze(0) | ||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||
| points_to_transform = points_t.unsqueeze(0) | ||||||||||||||||||||||||||||||||||
| converted = self.world_to_voxel(points_to_transform, affine).squeeze(0) | ||||||||||||||||||||||||||||||||||
| return converted | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def _get_reference_affine(self, reference: Any) -> torch.Tensor: | ||||||||||||||||||||||||||||||||||
| if reference is None: | ||||||||||||||||||||||||||||||||||
| raise ValueError("coordinate_space='world' requires ref_image_keys or a reference affine.") | ||||||||||||||||||||||||||||||||||
| affine = getattr(reference, "affine", None) | ||||||||||||||||||||||||||||||||||
| if affine is not None: | ||||||||||||||||||||||||||||||||||
| return affine | ||||||||||||||||||||||||||||||||||
| if isinstance(reference, (torch.Tensor, np.ndarray)) and reference.shape in ((3, 3), (4, 4)): | ||||||||||||||||||||||||||||||||||
| return reference | ||||||||||||||||||||||||||||||||||
| raise ValueError("coordinate_space='world' requires reference data with an affine matrix.") | ||||||||||||||||||||||||||||||||||
|
Comment on lines
+791
to
+799
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🎯 Functional Correctness | 🟠 Major | ⚡ Quick win Do not infer affines from raw 3x3/4x4 images. A normal 2D reference image with shape Safer fallback affine = getattr(reference, "affine", None)
if affine is not None:
return affine
- if isinstance(reference, (torch.Tensor, np.ndarray)) and reference.shape in ((3, 3), (4, 4)):
- return reference
raise ValueError("coordinate_space='world' requires reference data with an affine matrix.")📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def _compute_visibility(self, points: Any, spatial_shape: tuple[int, ...]) -> torch.Tensor: | ||||||||||||||||||||||||||||||||||
| points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False) | ||||||||||||||||||||||||||||||||||
| if points_t.ndim != 2: | ||||||||||||||||||||||||||||||||||
| raise ValueError(f"{self._ERR_INVALID_POINTS} Got {points_t.ndim}D tensor.") | ||||||||||||||||||||||||||||||||||
| bounds = torch.as_tensor(spatial_shape, dtype=points_t.dtype, device=points_t.device) | ||||||||||||||||||||||||||||||||||
| return torch.isfinite(points_t).all(dim=1) & (points_t >= 0).all(dim=1) & (points_t < bounds).all(dim=1) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def _convert_visibility(self, visibility: torch.Tensor, points: Any) -> NdarrayOrTensor: | ||||||||||||||||||||||||||||||||||
| if isinstance(points, (MetaTensor, torch.Tensor)): | ||||||||||||||||||||||||||||||||||
| return visibility.to(device=points.device, dtype=torch.bool) | ||||||||||||||||||||||||||||||||||
| if isinstance(points, np.ndarray): | ||||||||||||||||||||||||||||||||||
| return visibility.cpu().numpy().astype(bool) | ||||||||||||||||||||||||||||||||||
| return visibility.to(dtype=torch.bool) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| GenerateHeatmapD = GenerateHeatmapDict = GenerateHeatmapd | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
Uh oh!
There was an error while loading. Please reload this page.