diff --git a/docs/source/whatsnew_1_6.md b/docs/source/whatsnew_1_6.md index e4e43b0307..170dd3cfd6 100644 --- a/docs/source/whatsnew_1_6.md +++ b/docs/source/whatsnew_1_6.md @@ -9,6 +9,7 @@ - Nested dot-notation key access in `ConfigParser`. - Auto3DSeg algo serialization migrated from pickle to JSON for improved security and portability. - Global coordinates support in spatial crop transforms. These now support global coordinate mode, allowing crops to be specified in world/global coordinates rather than local image indices, improving interoperability with physical-space annotations. +- `GenerateHeatmapd` can convert world-coordinate landmarks to reference-image voxel space and emit landmark visibility masks. - `SoftclDiceLoss` and `SoftDiceclDiceLoss` enhanced with `DiceLoss`-compatible API - Variable expansion hardening has been added to the nnUNet app to eliminate code injection attacks when composing shell command lines, addressing concerns in [GHSA-rghg-q7wp-9767](https://github.com/Project-MONAI/MONAI/security/advisories/GHSA-rghg-q7wp-9767). - `NumpyReader` has been updated with an `allow_pickle` boolean argument to enable/disable pickle loading from `.npy/.npz` files. This was previously hard-coded to be enabled, but is now defined by this argument and disabled by default. This addresses [GHSA-qxq5-qhx6-94qw](https://github.com/Project-MONAI/MONAI/security/advisories/GHSA-qxq5-qhx6-94qw). diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 65fdd22b22..9e7f2c5928 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -46,7 +46,7 @@ VoteEnsemble, ) from monai.transforms.transform import MapTransform -from monai.transforms.utility.array import ToTensor +from monai.transforms.utility.array import ApplyTransformToPoints, ToTensor from monai.transforms.utils import allow_missing_keys_mode, convert_applied_interp_mode from monai.utils import PostFix, convert_to_tensor, ensure_tuple, ensure_tuple_rep from monai.utils.type_conversion import convert_to_dst_type @@ -527,6 +527,12 @@ class GenerateHeatmapd(MapTransform): heatmap_keys: keys to store output heatmaps. Default: "{key}_heatmap" for each key. ref_image_keys: keys of reference images to inherit spatial metadata from. When provided, heatmaps will have the same shape, affine, and spatial metadata as the reference images. + coordinate_space: coordinate system of the input points. ``"voxel"`` keeps the existing behavior and treats + points as voxel coordinates in the output heatmap space. ``"world"`` transforms points to reference-image + voxel coordinates with ``ref_image_keys`` before generating heatmaps. If the points are a ``MetaTensor`` + with their own affine, that affine is used as the point-to-world transform. + visibility_keys: optional keys to store a boolean visibility mask for each point after coordinate conversion. + The value is ``True`` when the transformed point is finite and inside the heatmap spatial shape. spatial_shape: spatial dimensions of output heatmaps. Can be: - Single shape (tuple): applied to all keys - List of shapes: one per key (must match keys length) @@ -542,6 +548,7 @@ class GenerateHeatmapd(MapTransform): ValueError: If heatmap_keys/ref_image_keys length doesn't match keys length. ValueError: If no spatial shape can be determined (need spatial_shape or ref_image_keys). ValueError: If input points have invalid shape (must be 2D array with shape (N, D)). + ValueError: If ``coordinate_space="world"`` is used without a reference affine. Example: .. code-block:: python @@ -573,12 +580,24 @@ class GenerateHeatmapd(MapTransform): result = transform(data) # result["landmarks_heatmap"] has shape (2, 64, 64) + # World-space landmarks can be converted against the reference affine. + transform = GenerateHeatmapd( + keys="landmarks_world", + heatmap_keys="landmark_heatmap", + ref_image_keys="image", + coordinate_space="world", + visibility_keys="landmark_visible", + sigma=2.0, + ) + Notes: - Default heatmap_keys are generated as "{key}_heatmap" for each input key - Shape inference precedence: static spatial_shape > ref_image - Input points shape: (N, D) where N is number of landmarks, D is spatial dimensions - Output heatmap shape: (N, H, W) for 2D or (N, H, W, D) for 3D - When using ref_image_keys, heatmaps inherit affine and spatial metadata from reference + - ``coordinate_space="world"`` assumes that the points and reference affine use the same world-coordinate + convention. Convert LPS/RAS conventions before calling this transform if needed. """ backend = GenerateHeatmap.backend @@ -590,6 +609,11 @@ class GenerateHeatmapd(MapTransform): _ERR_NO_SHAPE = "Unable to determine spatial shape for GenerateHeatmapd. Provide spatial_shape or ref_image_keys." _ERR_INVALID_POINTS = "Landmark arrays must be 2D with shape (N, D)." _ERR_REF_NO_SHAPE = "Reference data must define a shape attribute." + _ERR_VISIBILITY_KEYS_LEN = "Argument `visibility_keys` length must match keys length when provided." + _ERR_COORDINATE_SPACE_LEN = ( + "Argument `coordinate_space` length must match keys length when providing per-key values." + ) + _SUPPORTED_COORDINATE_SPACES = {"voxel", "world"} def __init__( self, @@ -597,6 +621,8 @@ def __init__( sigma: Sequence[float] | float = 5.0, heatmap_keys: KeysCollection | None = None, ref_image_keys: KeysCollection | None = None, + coordinate_space: str | Sequence[str] = "voxel", + visibility_keys: KeysCollection | None = None, spatial_shape: Sequence[int] | Sequence[Sequence[int]] | None = None, truncated: float = 4.0, normalize: bool = True, @@ -606,22 +632,27 @@ def __init__( super().__init__(keys, allow_missing_keys) self.heatmap_keys = self._prepare_heatmap_keys(heatmap_keys) self.ref_image_keys = self._prepare_optional_keys(ref_image_keys) + self.coordinate_spaces = self._prepare_coordinate_spaces(coordinate_space) + self.visibility_keys = self._prepare_visibility_keys(visibility_keys) self.static_shapes = self._prepare_shapes(spatial_shape) self.generator = GenerateHeatmap( sigma=sigma, spatial_shape=None, truncated=truncated, normalize=normalize, dtype=dtype ) + self.world_to_voxel = ApplyTransformToPoints(dtype=torch.float32, invert_affine=True) def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]: d = dict(data) - for key, out_key, ref_key, static_shape in self.key_iterator( - d, self.heatmap_keys, self.ref_image_keys, self.static_shapes + for key, out_key, ref_key, coordinate_space, visibility_key, static_shape in self.key_iterator( + d, self.heatmap_keys, self.ref_image_keys, self.coordinate_spaces, self.visibility_keys, self.static_shapes ): points = d[key] shape = self._determine_shape(points, static_shape, d, ref_key) + reference = d.get(ref_key) if ref_key is not None and ref_key in d else None + points = self._convert_points(points, reference, coordinate_space) + visibility = self._compute_visibility(points, shape) # The GenerateHeatmap transform will handle type conversion based on input points heatmap = self.generator(points, spatial_shape=shape) # If there's a reference image and we need to match its type/device - reference = d.get(ref_key) if ref_key is not None and ref_key in d else None if reference is not None and isinstance(reference, (torch.Tensor, np.ndarray)): # Convert to match reference type and device while preserving heatmap's dtype heatmap, _, _ = convert_to_dst_type( @@ -632,6 +663,8 @@ def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]: heatmap.affine = reference.affine self._update_spatial_metadata(heatmap, shape) d[out_key] = heatmap + if visibility_key is not None: + d[visibility_key] = self._convert_visibility(visibility, d[key]) return d def _prepare_heatmap_keys(self, heatmap_keys: KeysCollection | None) -> tuple[Hashable, ...]: @@ -654,6 +687,34 @@ def _prepare_optional_keys(self, maybe_keys: KeysCollection | None) -> tuple[Has raise ValueError(self._ERR_REF_KEYS_LEN) return tuple(keys_tuple) + def _prepare_visibility_keys(self, maybe_keys: KeysCollection | None) -> tuple[Hashable | None, ...]: + if maybe_keys is None: + return (None,) * len(self.keys) + keys_tuple = ensure_tuple(maybe_keys) + if len(keys_tuple) == 1 and len(self.keys) > 1: + keys_tuple = keys_tuple * len(self.keys) + if len(keys_tuple) != len(self.keys): + raise ValueError(self._ERR_VISIBILITY_KEYS_LEN) + return tuple(keys_tuple) + + def _prepare_coordinate_spaces(self, coordinate_space: str | Sequence[str]) -> tuple[str, ...]: + if isinstance(coordinate_space, str): + spaces = (coordinate_space,) * len(self.keys) + else: + spaces = ensure_tuple(coordinate_space) + if len(spaces) == 1 and len(self.keys) > 1: + spaces = spaces * len(self.keys) + if len(spaces) != len(self.keys): + raise ValueError(self._ERR_COORDINATE_SPACE_LEN) + spaces = tuple(str(space).lower() for space in spaces) + invalid = set(spaces) - self._SUPPORTED_COORDINATE_SPACES + if invalid: + raise ValueError( + f"Unsupported coordinate_space value: {sorted(invalid)}. " + f"Supported values are {sorted(self._SUPPORTED_COORDINATE_SPACES)}." + ) + return spaces + def _prepare_shapes( self, spatial_shape: Sequence[int] | Sequence[Sequence[int]] | None ) -> tuple[tuple[int, ...] | None, ...]: @@ -711,6 +772,46 @@ def _update_spatial_metadata(self, heatmap: MetaTensor, spatial_shape: tuple[int """Set spatial_shape explicitly from resolved shape.""" heatmap.meta["spatial_shape"] = tuple(int(v) for v in spatial_shape) + def _convert_points(self, points: Any, reference: Any, coordinate_space: str) -> Any: + if coordinate_space == "voxel": + return points + + affine = self._get_reference_affine(reference) + points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False) + if points_t.ndim != 2: + raise ValueError(f"{self._ERR_INVALID_POINTS} Got {points_t.ndim}D tensor.") + + if isinstance(points, MetaTensor): + points_to_transform = points.unsqueeze(0) + else: + points_to_transform = points_t.unsqueeze(0) + converted = self.world_to_voxel(points_to_transform, affine).squeeze(0) + return converted + + def _get_reference_affine(self, reference: Any) -> torch.Tensor: + if reference is None: + raise ValueError("coordinate_space='world' requires ref_image_keys or a reference affine.") + affine = getattr(reference, "affine", None) + if affine is not None: + return affine + if isinstance(reference, (torch.Tensor, np.ndarray)) and reference.shape in ((3, 3), (4, 4)): + return reference + raise ValueError("coordinate_space='world' requires reference data with an affine matrix.") + + def _compute_visibility(self, points: Any, spatial_shape: tuple[int, ...]) -> torch.Tensor: + points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False) + if points_t.ndim != 2: + raise ValueError(f"{self._ERR_INVALID_POINTS} Got {points_t.ndim}D tensor.") + bounds = torch.as_tensor(spatial_shape, dtype=points_t.dtype, device=points_t.device) + return torch.isfinite(points_t).all(dim=1) & (points_t >= 0).all(dim=1) & (points_t < bounds).all(dim=1) + + def _convert_visibility(self, visibility: torch.Tensor, points: Any) -> NdarrayOrTensor: + if isinstance(points, (MetaTensor, torch.Tensor)): + return visibility.to(device=points.device, dtype=torch.bool) + if isinstance(points, np.ndarray): + return visibility.cpu().numpy().astype(bool) + return visibility.to(dtype=torch.bool) + GenerateHeatmapD = GenerateHeatmapDict = GenerateHeatmapd diff --git a/tests/transforms/test_generate_heatmapd.py b/tests/transforms/test_generate_heatmapd.py index 0867a959e5..7c000a697f 100644 --- a/tests/transforms/test_generate_heatmapd.py +++ b/tests/transforms/test_generate_heatmapd.py @@ -21,6 +21,12 @@ from monai.transforms.post.dictionary import GenerateHeatmapd from tests.test_utils import assert_allclose + +def _peak_coord(channel: torch.Tensor) -> torch.Tensor: + idx = torch.argmax(channel) + return torch.stack(torch.unravel_index(idx, channel.shape)) + + # Test cases for dictionary transforms with reference image # Only test with non-MetaTensor types to avoid affine conflicts TEST_CASES_WITH_REF = [ @@ -220,6 +226,102 @@ def test_metatensor_points_with_ref(self): # Heatmap should inherit affine from the reference image assert_allclose(heatmap.affine, image.affine, type_test=False) + def test_world_points_with_reference_affine_and_visibility(self): + affine = torch.diag(torch.tensor([2.0, 2.0, 2.0, 1.0])) + image = MetaTensor(torch.zeros((1, 8, 8, 8), dtype=torch.float32), affine=affine) + image.meta["spatial_shape"] = (8, 8, 8) + points = torch.tensor( + [ + [4.0, 6.0, 8.0], # voxel coordinate [2, 3, 4] + [20.0, 0.0, 0.0], # out of bounds after affine conversion + [float("nan"), 0.0, 0.0], + ], + dtype=torch.float32, + ) + + transform = GenerateHeatmapd( + keys="points", + heatmap_keys="heatmap", + ref_image_keys="image", + coordinate_space="world", + visibility_keys="visible", + sigma=1.0, + ) + result = transform({"points": points, "image": image}) + + heatmap = result["heatmap"] + self.assertIsInstance(heatmap, MetaTensor) + self.assertEqual(tuple(heatmap.shape), (3, 8, 8, 8)) + assert_allclose(_peak_coord(heatmap[0]), torch.tensor([2, 3, 4]), type_test=False) + self.assertTrue(torch.equal(result["visible"], torch.tensor([True, False, False]))) + self.assertGreater(heatmap[0].max(), 0.99) + self.assertEqual(float(heatmap[1].max()), 0.0) + self.assertEqual(float(heatmap[2].max()), 0.0) + + def test_world_points_with_translated_rotated_affine(self): + affine = torch.tensor( + [[0.0, -2.0, 0.0, 10.0], [3.0, 0.0, 0.0, 20.0], [0.0, 0.0, 4.0, 30.0], [0.0, 0.0, 0.0, 1.0]], + dtype=torch.float32, + ) + image = MetaTensor(torch.zeros((1, 8, 8, 8), dtype=torch.float32), affine=affine) + image.meta["spatial_shape"] = (8, 8, 8) + voxel_point = torch.tensor([2.0, 3.0, 4.0], dtype=torch.float32) + world_point = affine[:3, :3] @ voxel_point + affine[:3, 3] + + transform = GenerateHeatmapd( + keys="points", + heatmap_keys="heatmap", + ref_image_keys="image", + coordinate_space="world", + visibility_keys="visible", + sigma=1.0, + ) + result = transform({"points": world_point[None], "image": image}) + + assert_allclose(_peak_coord(result["heatmap"][0]), voxel_point.to(torch.long), type_test=False) + self.assertTrue(torch.equal(result["visible"], torch.tensor([True]))) + + def test_world_metatensor_points_use_point_affine(self): + image = MetaTensor(torch.zeros((1, 8, 8, 8), dtype=torch.float32), affine=torch.eye(4)) + image.meta["spatial_shape"] = (8, 8, 8) + points_affine = torch.diag(torch.tensor([2.0, 2.0, 2.0, 1.0])) + points = MetaTensor(torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float32), affine=points_affine) + + transform = GenerateHeatmapd( + keys="points", + heatmap_keys="heatmap", + ref_image_keys="image", + coordinate_space="world", + visibility_keys="visible", + sigma=1.0, + ) + result = transform({"points": points, "image": image}) + + assert_allclose(_peak_coord(result["heatmap"][0]), torch.tensor([2, 4, 6]), type_test=False) + self.assertIsInstance(result["visible"], torch.Tensor) + self.assertNotIsInstance(result["visible"], MetaTensor) + self.assertTrue(bool(result["visible"][0])) + + def test_world_points_require_reference_affine(self): + transform = GenerateHeatmapd( + keys="points", heatmap_keys="heatmap", spatial_shape=(8, 8, 8), coordinate_space="world" + ) + with self.assertRaisesRegex(ValueError, "reference|affine|ref_image_keys"): + transform({"points": torch.zeros((1, 3), dtype=torch.float32)}) + + def test_invalid_coordinate_space_raises(self): + with self.assertRaisesRegex(ValueError, "coordinate_space"): + GenerateHeatmapd(keys="points", heatmap_keys="heatmap", spatial_shape=(8, 8), coordinate_space="scanner") + + def test_visibility_key_length_mismatch_raises(self): + with self.assertRaises(ValueError): + GenerateHeatmapd( + keys=["pts1", "pts2"], + heatmap_keys=["hm1", "hm2"], + visibility_keys=["visible1", "visible2", "visible3"], + spatial_shape=(8, 8), + ) + if __name__ == "__main__": unittest.main()