Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/whatsnew_1_6.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
- Nested dot-notation key access in `ConfigParser`.
- Auto3DSeg algo serialization migrated from pickle to JSON for improved security and portability.
- Global coordinates support in spatial crop transforms. These now support global coordinate mode, allowing crops to be specified in world/global coordinates rather than local image indices, improving interoperability with physical-space annotations.
- `GenerateHeatmapd` can convert world-coordinate landmarks to reference-image voxel space and emit landmark visibility masks.
- `SoftclDiceLoss` and `SoftDiceclDiceLoss` enhanced with `DiceLoss`-compatible API
- Variable expansion hardening has been added to the nnUNet app to eliminate code injection attacks when composing shell command lines, addressing concerns in [GHSA-rghg-q7wp-9767](https://github.com/Project-MONAI/MONAI/security/advisories/GHSA-rghg-q7wp-9767).
- `NumpyReader` has been updated with an `allow_pickle` boolean argument to enable/disable pickle loading from `.npy/.npz` files. This was previously hard-coded to be enabled, but is now defined by this argument and disabled by default. This addresses [GHSA-qxq5-qhx6-94qw](https://github.com/Project-MONAI/MONAI/security/advisories/GHSA-qxq5-qhx6-94qw).
Expand Down
109 changes: 105 additions & 4 deletions monai/transforms/post/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
VoteEnsemble,
)
from monai.transforms.transform import MapTransform
from monai.transforms.utility.array import ToTensor
from monai.transforms.utility.array import ApplyTransformToPoints, ToTensor
from monai.transforms.utils import allow_missing_keys_mode, convert_applied_interp_mode
from monai.utils import PostFix, convert_to_tensor, ensure_tuple, ensure_tuple_rep
from monai.utils.type_conversion import convert_to_dst_type
Expand Down Expand Up @@ -527,6 +527,12 @@ class GenerateHeatmapd(MapTransform):
heatmap_keys: keys to store output heatmaps. Default: "{key}_heatmap" for each key.
ref_image_keys: keys of reference images to inherit spatial metadata from. When provided, heatmaps will
have the same shape, affine, and spatial metadata as the reference images.
coordinate_space: coordinate system of the input points. ``"voxel"`` keeps the existing behavior and treats
points as voxel coordinates in the output heatmap space. ``"world"`` transforms points to reference-image
voxel coordinates with ``ref_image_keys`` before generating heatmaps. If the points are a ``MetaTensor``
with their own affine, that affine is used as the point-to-world transform.
visibility_keys: optional keys to store a boolean visibility mask for each point after coordinate conversion.
The value is ``True`` when the transformed point is finite and inside the heatmap spatial shape.
spatial_shape: spatial dimensions of output heatmaps. Can be:
- Single shape (tuple): applied to all keys
- List of shapes: one per key (must match keys length)
Expand All @@ -542,6 +548,7 @@ class GenerateHeatmapd(MapTransform):
ValueError: If heatmap_keys/ref_image_keys length doesn't match keys length.
ValueError: If no spatial shape can be determined (need spatial_shape or ref_image_keys).
ValueError: If input points have invalid shape (must be 2D array with shape (N, D)).
ValueError: If ``coordinate_space="world"`` is used without a reference affine.

Example:
.. code-block:: python
Expand Down Expand Up @@ -573,12 +580,24 @@ class GenerateHeatmapd(MapTransform):
result = transform(data)
# result["landmarks_heatmap"] has shape (2, 64, 64)

# World-space landmarks can be converted against the reference affine.
transform = GenerateHeatmapd(
keys="landmarks_world",
heatmap_keys="landmark_heatmap",
ref_image_keys="image",
coordinate_space="world",
visibility_keys="landmark_visible",
sigma=2.0,
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.

Notes:
- Default heatmap_keys are generated as "{key}_heatmap" for each input key
- Shape inference precedence: static spatial_shape > ref_image
- Input points shape: (N, D) where N is number of landmarks, D is spatial dimensions
- Output heatmap shape: (N, H, W) for 2D or (N, H, W, D) for 3D
- When using ref_image_keys, heatmaps inherit affine and spatial metadata from reference
- ``coordinate_space="world"`` assumes that the points and reference affine use the same world-coordinate
convention. Convert LPS/RAS conventions before calling this transform if needed.
"""

backend = GenerateHeatmap.backend
Expand All @@ -590,13 +609,20 @@ class GenerateHeatmapd(MapTransform):
_ERR_NO_SHAPE = "Unable to determine spatial shape for GenerateHeatmapd. Provide spatial_shape or ref_image_keys."
_ERR_INVALID_POINTS = "Landmark arrays must be 2D with shape (N, D)."
_ERR_REF_NO_SHAPE = "Reference data must define a shape attribute."
_ERR_VISIBILITY_KEYS_LEN = "Argument `visibility_keys` length must match keys length when provided."
_ERR_COORDINATE_SPACE_LEN = (
"Argument `coordinate_space` length must match keys length when providing per-key values."
)
_SUPPORTED_COORDINATE_SPACES = {"voxel", "world"}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win

Make the coordinate-space set immutable.

Ruff flags this mutable class attribute. Use frozenset.

Proposed fix
-    _SUPPORTED_COORDINATE_SPACES = {"voxel", "world"}
+    _SUPPORTED_COORDINATE_SPACES = frozenset({"voxel", "world"})
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
_SUPPORTED_COORDINATE_SPACES = {"voxel", "world"}
_SUPPORTED_COORDINATE_SPACES = frozenset({"voxel", "world"})
🧰 Tools
🪛 Ruff (0.15.18)

[warning] 616-616: Mutable default value for class attribute

(RUF012)

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/transforms/post/dictionary.py` at line 616, The
_SUPPORTED_COORDINATE_SPACES class attribute in the coordinate-space transform
code is a mutable set and should be made immutable. Update the attribute to use
frozenset in the same class/module where it is defined so Ruff no longer flags
it, keeping the existing "voxel" and "world" values unchanged.

Source: Linters/SAST tools


def __init__(
self,
keys: KeysCollection,
sigma: Sequence[float] | float = 5.0,
heatmap_keys: KeysCollection | None = None,
ref_image_keys: KeysCollection | None = None,
coordinate_space: str | Sequence[str] = "voxel",
visibility_keys: KeysCollection | None = None,
spatial_shape: Sequence[int] | Sequence[Sequence[int]] | None = None,
truncated: float = 4.0,
normalize: bool = True,
Expand All @@ -606,22 +632,27 @@ def __init__(
super().__init__(keys, allow_missing_keys)
self.heatmap_keys = self._prepare_heatmap_keys(heatmap_keys)
self.ref_image_keys = self._prepare_optional_keys(ref_image_keys)
self.coordinate_spaces = self._prepare_coordinate_spaces(coordinate_space)
self.visibility_keys = self._prepare_visibility_keys(visibility_keys)
self.static_shapes = self._prepare_shapes(spatial_shape)
self.generator = GenerateHeatmap(
sigma=sigma, spatial_shape=None, truncated=truncated, normalize=normalize, dtype=dtype
)
self.world_to_voxel = ApplyTransformToPoints(dtype=torch.float32, invert_affine=True)

def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]:
d = dict(data)
for key, out_key, ref_key, static_shape in self.key_iterator(
d, self.heatmap_keys, self.ref_image_keys, self.static_shapes
for key, out_key, ref_key, coordinate_space, visibility_key, static_shape in self.key_iterator(
d, self.heatmap_keys, self.ref_image_keys, self.coordinate_spaces, self.visibility_keys, self.static_shapes
):
points = d[key]
shape = self._determine_shape(points, static_shape, d, ref_key)
reference = d.get(ref_key) if ref_key is not None and ref_key in d else None
points = self._convert_points(points, reference, coordinate_space)
visibility = self._compute_visibility(points, shape)
# The GenerateHeatmap transform will handle type conversion based on input points
heatmap = self.generator(points, spatial_shape=shape)
# If there's a reference image and we need to match its type/device
reference = d.get(ref_key) if ref_key is not None and ref_key in d else None
if reference is not None and isinstance(reference, (torch.Tensor, np.ndarray)):
# Convert to match reference type and device while preserving heatmap's dtype
heatmap, _, _ = convert_to_dst_type(
Expand All @@ -632,6 +663,8 @@ def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]:
heatmap.affine = reference.affine
self._update_spatial_metadata(heatmap, shape)
d[out_key] = heatmap
if visibility_key is not None:
d[visibility_key] = self._convert_visibility(visibility, d[key])
return d

def _prepare_heatmap_keys(self, heatmap_keys: KeysCollection | None) -> tuple[Hashable, ...]:
Expand All @@ -654,6 +687,34 @@ def _prepare_optional_keys(self, maybe_keys: KeysCollection | None) -> tuple[Has
raise ValueError(self._ERR_REF_KEYS_LEN)
return tuple(keys_tuple)

def _prepare_visibility_keys(self, maybe_keys: KeysCollection | None) -> tuple[Hashable | None, ...]:
if maybe_keys is None:
return (None,) * len(self.keys)
keys_tuple = ensure_tuple(maybe_keys)
if len(keys_tuple) == 1 and len(self.keys) > 1:
keys_tuple = keys_tuple * len(self.keys)
if len(keys_tuple) != len(self.keys):
raise ValueError(self._ERR_VISIBILITY_KEYS_LEN)
return tuple(keys_tuple)

def _prepare_coordinate_spaces(self, coordinate_space: str | Sequence[str]) -> tuple[str, ...]:
if isinstance(coordinate_space, str):
spaces = (coordinate_space,) * len(self.keys)
else:
spaces = ensure_tuple(coordinate_space)
if len(spaces) == 1 and len(self.keys) > 1:
spaces = spaces * len(self.keys)
if len(spaces) != len(self.keys):
raise ValueError(self._ERR_COORDINATE_SPACE_LEN)
spaces = tuple(str(space).lower() for space in spaces)
invalid = set(spaces) - self._SUPPORTED_COORDINATE_SPACES
if invalid:
raise ValueError(
f"Unsupported coordinate_space value: {sorted(invalid)}. "
f"Supported values are {sorted(self._SUPPORTED_COORDINATE_SPACES)}."
)
return spaces

def _prepare_shapes(
self, spatial_shape: Sequence[int] | Sequence[Sequence[int]] | None
) -> tuple[tuple[int, ...] | None, ...]:
Expand Down Expand Up @@ -711,6 +772,46 @@ def _update_spatial_metadata(self, heatmap: MetaTensor, spatial_shape: tuple[int
"""Set spatial_shape explicitly from resolved shape."""
heatmap.meta["spatial_shape"] = tuple(int(v) for v in spatial_shape)

def _convert_points(self, points: Any, reference: Any, coordinate_space: str) -> Any:
if coordinate_space == "voxel":
return points

affine = self._get_reference_affine(reference)
points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False)
if points_t.ndim != 2:
raise ValueError(f"{self._ERR_INVALID_POINTS} Got {points_t.ndim}D tensor.")

if isinstance(points, MetaTensor):
points_to_transform = points.unsqueeze(0)
else:
points_to_transform = points_t.unsqueeze(0)
converted = self.world_to_voxel(points_to_transform, affine).squeeze(0)
return converted

def _get_reference_affine(self, reference: Any) -> torch.Tensor:
if reference is None:
raise ValueError("coordinate_space='world' requires ref_image_keys or a reference affine.")
affine = getattr(reference, "affine", None)
if affine is not None:
return affine
if isinstance(reference, (torch.Tensor, np.ndarray)) and reference.shape in ((3, 3), (4, 4)):
return reference
raise ValueError("coordinate_space='world' requires reference data with an affine matrix.")
Comment on lines +791 to +799

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎯 Functional Correctness | 🟠 Major | ⚡ Quick win

Do not infer affines from raw 3x3/4x4 images.

A normal 2D reference image with shape (3, 3) or (4, 4) will be treated as an affine and inverted, causing wrong coordinates or a singular-matrix crash instead of the documented missing-affine error.

Safer fallback
         affine = getattr(reference, "affine", None)
         if affine is not None:
             return affine
-        if isinstance(reference, (torch.Tensor, np.ndarray)) and reference.shape in ((3, 3), (4, 4)):
-            return reference
         raise ValueError("coordinate_space='world' requires reference data with an affine matrix.")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def _get_reference_affine(self, reference: Any) -> torch.Tensor:
if reference is None:
raise ValueError("coordinate_space='world' requires ref_image_keys or a reference affine.")
affine = getattr(reference, "affine", None)
if affine is not None:
return affine
if isinstance(reference, (torch.Tensor, np.ndarray)) and reference.shape in ((3, 3), (4, 4)):
return reference
raise ValueError("coordinate_space='world' requires reference data with an affine matrix.")
def _get_reference_affine(self, reference: Any) -> torch.Tensor:
if reference is None:
raise ValueError("coordinate_space='world' requires ref_image_keys or a reference affine.")
affine = getattr(reference, "affine", None)
if affine is not None:
return affine
raise ValueError("coordinate_space='world' requires reference data with an affine matrix.")
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/transforms/post/dictionary.py` around lines 796 - 804, The
_get_reference_affine helper in the dictionary transform is incorrectly treating
any raw 3x3/4x4 tensor or ndarray as an affine, which can misclassify normal
reference images. Update this logic so only objects with an explicit affine
attribute are accepted for coordinate_space='world', and otherwise raise the
documented missing-affine ValueError instead of inferring from shape; keep the
fix localized to _get_reference_affine and its callers.


def _compute_visibility(self, points: Any, spatial_shape: tuple[int, ...]) -> torch.Tensor:
points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False)
if points_t.ndim != 2:
raise ValueError(f"{self._ERR_INVALID_POINTS} Got {points_t.ndim}D tensor.")
bounds = torch.as_tensor(spatial_shape, dtype=points_t.dtype, device=points_t.device)
return torch.isfinite(points_t).all(dim=1) & (points_t >= 0).all(dim=1) & (points_t < bounds).all(dim=1)

def _convert_visibility(self, visibility: torch.Tensor, points: Any) -> NdarrayOrTensor:
if isinstance(points, (MetaTensor, torch.Tensor)):
return visibility.to(device=points.device, dtype=torch.bool)
if isinstance(points, np.ndarray):
return visibility.cpu().numpy().astype(bool)
return visibility.to(dtype=torch.bool)


GenerateHeatmapD = GenerateHeatmapDict = GenerateHeatmapd

Expand Down
102 changes: 102 additions & 0 deletions tests/transforms/test_generate_heatmapd.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@
from monai.transforms.post.dictionary import GenerateHeatmapd
from tests.test_utils import assert_allclose


def _peak_coord(channel: torch.Tensor) -> torch.Tensor:
idx = torch.argmax(channel)
return torch.stack(torch.unravel_index(idx, channel.shape))


# Test cases for dictionary transforms with reference image
# Only test with non-MetaTensor types to avoid affine conflicts
TEST_CASES_WITH_REF = [
Expand Down Expand Up @@ -220,6 +226,102 @@ def test_metatensor_points_with_ref(self):
# Heatmap should inherit affine from the reference image
assert_allclose(heatmap.affine, image.affine, type_test=False)

def test_world_points_with_reference_affine_and_visibility(self):
affine = torch.diag(torch.tensor([2.0, 2.0, 2.0, 1.0]))
image = MetaTensor(torch.zeros((1, 8, 8, 8), dtype=torch.float32), affine=affine)
image.meta["spatial_shape"] = (8, 8, 8)
points = torch.tensor(
[
[4.0, 6.0, 8.0], # voxel coordinate [2, 3, 4]
[20.0, 0.0, 0.0], # out of bounds after affine conversion
[float("nan"), 0.0, 0.0],
],
dtype=torch.float32,
)

transform = GenerateHeatmapd(
keys="points",
heatmap_keys="heatmap",
ref_image_keys="image",
coordinate_space="world",
visibility_keys="visible",
sigma=1.0,
)
result = transform({"points": points, "image": image})

heatmap = result["heatmap"]
self.assertIsInstance(heatmap, MetaTensor)
self.assertEqual(tuple(heatmap.shape), (3, 8, 8, 8))
assert_allclose(_peak_coord(heatmap[0]), torch.tensor([2, 3, 4]), type_test=False)
self.assertTrue(torch.equal(result["visible"], torch.tensor([True, False, False])))
self.assertGreater(heatmap[0].max(), 0.99)
self.assertEqual(float(heatmap[1].max()), 0.0)
self.assertEqual(float(heatmap[2].max()), 0.0)

def test_world_points_with_translated_rotated_affine(self):
affine = torch.tensor(
[[0.0, -2.0, 0.0, 10.0], [3.0, 0.0, 0.0, 20.0], [0.0, 0.0, 4.0, 30.0], [0.0, 0.0, 0.0, 1.0]],
dtype=torch.float32,
)
image = MetaTensor(torch.zeros((1, 8, 8, 8), dtype=torch.float32), affine=affine)
image.meta["spatial_shape"] = (8, 8, 8)
voxel_point = torch.tensor([2.0, 3.0, 4.0], dtype=torch.float32)
world_point = affine[:3, :3] @ voxel_point + affine[:3, 3]

transform = GenerateHeatmapd(
keys="points",
heatmap_keys="heatmap",
ref_image_keys="image",
coordinate_space="world",
visibility_keys="visible",
sigma=1.0,
)
result = transform({"points": world_point[None], "image": image})

assert_allclose(_peak_coord(result["heatmap"][0]), voxel_point.to(torch.long), type_test=False)
self.assertTrue(torch.equal(result["visible"], torch.tensor([True])))

def test_world_metatensor_points_use_point_affine(self):
image = MetaTensor(torch.zeros((1, 8, 8, 8), dtype=torch.float32), affine=torch.eye(4))
image.meta["spatial_shape"] = (8, 8, 8)
points_affine = torch.diag(torch.tensor([2.0, 2.0, 2.0, 1.0]))
points = MetaTensor(torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float32), affine=points_affine)

transform = GenerateHeatmapd(
keys="points",
heatmap_keys="heatmap",
ref_image_keys="image",
coordinate_space="world",
visibility_keys="visible",
sigma=1.0,
)
result = transform({"points": points, "image": image})

assert_allclose(_peak_coord(result["heatmap"][0]), torch.tensor([2, 4, 6]), type_test=False)
self.assertIsInstance(result["visible"], torch.Tensor)
self.assertNotIsInstance(result["visible"], MetaTensor)
self.assertTrue(bool(result["visible"][0]))

def test_world_points_require_reference_affine(self):
transform = GenerateHeatmapd(
keys="points", heatmap_keys="heatmap", spatial_shape=(8, 8, 8), coordinate_space="world"
)
with self.assertRaisesRegex(ValueError, "reference|affine|ref_image_keys"):
transform({"points": torch.zeros((1, 3), dtype=torch.float32)})

def test_invalid_coordinate_space_raises(self):
with self.assertRaisesRegex(ValueError, "coordinate_space"):
GenerateHeatmapd(keys="points", heatmap_keys="heatmap", spatial_shape=(8, 8), coordinate_space="scanner")

def test_visibility_key_length_mismatch_raises(self):
with self.assertRaises(ValueError):
GenerateHeatmapd(
keys=["pts1", "pts2"],
heatmap_keys=["hm1", "hm2"],
visibility_keys=["visible1", "visible2", "visible3"],
spatial_shape=(8, 8),
)


if __name__ == "__main__":
unittest.main()
Loading