From f49d6b2c0d86ab47c4e666933ff2862a0b9ee717 Mon Sep 17 00:00:00 2001 From: Soumya Snigdha Kundu Date: Mon, 29 Jun 2026 17:56:24 +0100 Subject: [PATCH 1/2] feat(transforms): add randomize_per_key option to random dict transforms (#7561) Random dictionary transforms draw their parameters once and apply that same draw to every key. This adds an opt-in randomize_per_key flag so each key can draw independently instead, which is useful when keys hold unrelated images such as different views in self-supervised learning. This mirrors TorchIO's per-transform randomization, but as an opt-in flag rather than a change to the Randomizable base class. Default is False, so existing behaviour is unchanged. The flag is added to the eleven random intensity transforms and to four spatial transforms (RandAxisFlipd, RandRotated, RandZoomd, RandGridDistortiond); other transforms are left out for now because they precompute a shared sampling grid, emit multiple correlated samples, or mix across keys, so per-key randomization there needs more than a flag. Signed-off-by: Soumya Snigdha Kundu --- monai/transforms/intensity/dictionary.py | 121 ++++++++++++++++------- monai/transforms/spatial/dictionary.py | 58 ++++++++--- 2 files changed, 134 insertions(+), 45 deletions(-) diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 0c25d4ac99..101852a88c 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -188,6 +188,8 @@ class RandGaussianNoised(RandomizableTransform, MapTransform): dtype: output data type, if None, same as input image. defaults to float32. allow_missing_keys: don't raise exception if key is missing. sample_std: If True, sample the spread of the Gaussian distribution uniformly from 0 to std. + randomize_per_key: if True, draw independent random parameters for each key instead of sharing + them across all keys (e.g. for independent views in self-supervised learning). Defaults to False. """ backend = RandGaussianNoise.backend @@ -201,9 +203,11 @@ def __init__( dtype: DtypeLike = np.float32, allow_missing_keys: bool = False, sample_std: bool = True, + randomize_per_key: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) + self.randomize_per_key = randomize_per_key self.rand_gaussian_noise = RandGaussianNoise(mean=mean, std=std, prob=1.0, dtype=dtype, sample_std=sample_std) def set_random_state( @@ -221,17 +225,18 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) return d - # all the keys share the same random noise first_key: Hashable = self.first_key(d) if first_key == (): for key in self.key_iterator(d): d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) return d - self.rand_gaussian_noise.randomize(d[first_key]) + if not self.randomize_per_key: + # all the keys share the same random noise + self.rand_gaussian_noise.randomize(d[first_key]) for key in self.key_iterator(d): - d[key] = self.rand_gaussian_noise(img=d[key], randomize=False) + d[key] = self.rand_gaussian_noise(img=d[key], randomize=self.randomize_per_key) return d @@ -381,6 +386,7 @@ def __init__( prob: float = 0.1, channel_wise: bool = False, allow_missing_keys: bool = False, + randomize_per_key: bool = False, ) -> None: """ Args: @@ -409,10 +415,13 @@ def __init__( channel_wise: if True, shift intensity on each channel separately. For each channel, a random offset will be chosen. Please ensure that the first dimension represents the channel of the image if True. allow_missing_keys: don't raise exception if key is missing. + randomize_per_key: if True, draw independent random parameters for each key instead of sharing + them across all keys (e.g. for independent views in self-supervised learning). Defaults to False. """ MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) + self.randomize_per_key = randomize_per_key self.factor_key = ensure_tuple_rep(factor_key, len(self.keys)) self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys) if len(self.keys) != len(self.meta_keys): @@ -442,14 +451,15 @@ def __call__(self, data) -> dict[Hashable, NdarrayOrTensor]: d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) return d - # all the keys share the same random shift factor - self.shifter.randomize(d[first_key]) + if not self.randomize_per_key: + # all the keys share the same random shift factor + self.shifter.randomize(d[first_key]) for key, factor_key, meta_key, meta_key_postfix in self.key_iterator( d, self.factor_key, self.meta_keys, self.meta_key_postfix ): meta_key = meta_key or f"{key}_{meta_key_postfix}" factor: float | None = d[meta_key].get(factor_key) if meta_key in d else None - d[key] = self.shifter(d[key], factor=factor, randomize=False) + d[key] = self.shifter(d[key], factor=factor, randomize=self.randomize_per_key) return d @@ -506,6 +516,7 @@ def __init__( channel_wise: bool = False, dtype: DtypeLike = np.float32, allow_missing_keys: bool = False, + randomize_per_key: bool = False, ) -> None: """ Args: @@ -518,9 +529,12 @@ def __init__( channel_wise: if True, calculate on each channel separately. dtype: output data type, if None, same as input image. defaults to float32. allow_missing_keys: don't raise exception if key is missing. + randomize_per_key: if True, draw independent random parameters for each key instead of sharing + them across all keys (e.g. for independent views in self-supervised learning). Defaults to False. """ MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) + self.randomize_per_key = randomize_per_key self.shifter = RandStdShiftIntensity( factors=factors, nonzero=nonzero, channel_wise=channel_wise, dtype=dtype, prob=1.0 ) @@ -540,10 +554,11 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) return d - # all the keys share the same random shift factor - self.shifter.randomize(None) + if not self.randomize_per_key: + # all the keys share the same random shift factor + self.shifter.randomize(None) for key in self.key_iterator(d): - d[key] = self.shifter(d[key], randomize=False) + d[key] = self.shifter(d[key], randomize=self.randomize_per_key) return d @@ -605,6 +620,7 @@ def __init__( channel_wise: bool = False, dtype: DtypeLike = np.float32, allow_missing_keys: bool = False, + randomize_per_key: bool = False, ) -> None: """ Args: @@ -618,10 +634,13 @@ def __init__( that the first dimension represents the channel of the image if True. dtype: output data type, if None, same as input image. defaults to float32. allow_missing_keys: don't raise exception if key is missing. + randomize_per_key: if True, draw independent random parameters for each key instead of sharing + them across all keys (e.g. for independent views in self-supervised learning). Defaults to False. """ MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) + self.randomize_per_key = randomize_per_key self.scaler = RandScaleIntensity(factors=factors, dtype=dtype, prob=1.0, channel_wise=channel_wise) def set_random_state( @@ -646,10 +665,11 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) return d - # all the keys share the same random scale factor - self.scaler.randomize(d[first_key]) + if not self.randomize_per_key: + # all the keys share the same random scale factor + self.scaler.randomize(d[first_key]) for key in self.key_iterator(d): - d[key] = self.scaler(d[key], randomize=False) + d[key] = self.scaler(d[key], randomize=self.randomize_per_key) return d @@ -672,6 +692,7 @@ def __init__( dtype: DtypeLike = np.float32, allow_missing_keys: bool = False, channel_wise: bool = False, + randomize_per_key: bool = False, ) -> None: """ Args: @@ -687,10 +708,13 @@ def __init__( channel_wise: if True, scale on each channel separately. `preserve_range` and `fixed_mean` are also applied on each channel separately if `channel_wise` is True. Please ensure that the first dimension represents the channel of the image if True. + randomize_per_key: if True, draw independent random parameters for each key instead of sharing + them across all keys (e.g. for independent views in self-supervised learning). Defaults to False. """ MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) + self.randomize_per_key = randomize_per_key self.fixed_mean = fixed_mean self.preserve_range = preserve_range self.scaler = RandScaleIntensityFixedMean( @@ -724,10 +748,11 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) return d - # all the keys share the same random scale factor - self.scaler.randomize(d[first_key]) + if not self.randomize_per_key: + # all the keys share the same random scale factor + self.scaler.randomize(d[first_key]) for key in self.key_iterator(d): - d[key] = self.scaler(d[key], randomize=False) + d[key] = self.scaler(d[key], randomize=self.randomize_per_key) return d @@ -746,6 +771,7 @@ def __init__( dtype: DtypeLike = np.float32, prob: float = 0.1, allow_missing_keys: bool = False, + randomize_per_key: bool = False, ) -> None: """ Args: @@ -757,11 +783,14 @@ def __init__( dtype: output data type, if None, same as input image. defaults to float32. prob: probability to do random bias field. allow_missing_keys: don't raise exception if key is missing. + randomize_per_key: if True, draw independent random parameters for each key instead of sharing + them across all keys (e.g. for independent views in self-supervised learning). Defaults to False. """ MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) + self.randomize_per_key = randomize_per_key self.rand_bias_field = RandBiasField(degree=degree, coeff_range=coeff_range, dtype=dtype, prob=1.0) def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandBiasFieldd: @@ -777,17 +806,18 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) return d - # all the keys share the same random bias factor first_key: Hashable = self.first_key(d) if first_key == (): for key in self.key_iterator(d): d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) return d - self.rand_bias_field.randomize(img_size=d[first_key].shape[1:]) + if not self.randomize_per_key: + # all the keys share the same random bias factor + self.rand_bias_field.randomize(img_size=d[first_key].shape[1:]) for key in self.key_iterator(d): - d[key] = self.rand_bias_field(d[key], randomize=False) + d[key] = self.rand_bias_field(d[key], randomize=self.randomize_per_key) return d @@ -1003,6 +1033,8 @@ class RandAdjustContrastd(RandomizableTransform, MapTransform): `_ function. allow_missing_keys: don't raise exception if key is missing. + randomize_per_key: if True, draw independent random parameters for each key instead of sharing + them across all keys (e.g. for independent views in self-supervised learning). Defaults to False. """ backend = RandAdjustContrast.backend @@ -1015,9 +1047,11 @@ def __init__( invert_image: bool = False, retain_stats: bool = False, allow_missing_keys: bool = False, + randomize_per_key: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) + self.randomize_per_key = randomize_per_key self.adjuster = RandAdjustContrast(gamma=gamma, prob=1.0, invert_image=invert_image, retain_stats=retain_stats) self.invert_image = invert_image @@ -1036,10 +1070,11 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) return d - # all the keys share the same random gamma value - self.adjuster.randomize(None) + if not self.randomize_per_key: + # all the keys share the same random gamma value + self.adjuster.randomize(None) for key in self.key_iterator(d): - d[key] = self.adjuster(d[key], randomize=False) + d[key] = self.adjuster(d[key], randomize=self.randomize_per_key) return d @@ -1242,6 +1277,8 @@ class RandGaussianSmoothd(RandomizableTransform, MapTransform): see also :py:meth:`monai.networks.layers.GaussianFilter`. prob: probability of Gaussian smooth. allow_missing_keys: don't raise exception if key is missing. + randomize_per_key: if True, draw independent random parameters for each key instead of sharing + them across all keys (e.g. for independent views in self-supervised learning). Defaults to False. """ @@ -1256,9 +1293,11 @@ def __init__( approx: str = "erf", prob: float = 0.1, allow_missing_keys: bool = False, + randomize_per_key: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) + self.randomize_per_key = randomize_per_key self.rand_smooth = RandGaussianSmooth( sigma_x=sigma_x, sigma_y=sigma_y, sigma_z=sigma_z, approx=approx, prob=1.0 ) @@ -1278,10 +1317,11 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) return d - # all the keys share the same random sigma - self.rand_smooth.randomize(None) + if not self.randomize_per_key: + # all the keys share the same random sigma + self.rand_smooth.randomize(None) for key in self.key_iterator(d): - d[key] = self.rand_smooth(d[key], randomize=False) + d[key] = self.rand_smooth(d[key], randomize=self.randomize_per_key) return d @@ -1347,6 +1387,8 @@ class RandGaussianSharpend(RandomizableTransform, MapTransform): see also :py:meth:`monai.networks.layers.GaussianFilter`. prob: probability of Gaussian sharpen. allow_missing_keys: don't raise exception if key is missing. + randomize_per_key: if True, draw independent random parameters for each key instead of sharing + them across all keys (e.g. for independent views in self-supervised learning). Defaults to False. """ @@ -1365,9 +1407,11 @@ def __init__( approx: str = "erf", prob: float = 0.1, allow_missing_keys: bool = False, + randomize_per_key: bool = False, ): MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) + self.randomize_per_key = randomize_per_key self.rand_sharpen = RandGaussianSharpen( sigma1_x=sigma1_x, sigma1_y=sigma1_y, @@ -1395,10 +1439,11 @@ def __call__(self, data: dict[Hashable, NdarrayOrTensor]) -> dict[Hashable, Ndar d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) return d - # all the keys share the same random sigma1, sigma2, etc. - self.rand_sharpen.randomize(None) + if not self.randomize_per_key: + # all the keys share the same random sigma1, sigma2, etc. + self.rand_sharpen.randomize(None) for key in self.key_iterator(d): - d[key] = self.rand_sharpen(d[key], randomize=False) + d[key] = self.rand_sharpen(d[key], randomize=self.randomize_per_key) return d @@ -1415,6 +1460,8 @@ class RandHistogramShiftd(RandomizableTransform, MapTransform): control points selecting from range (min_value, max_value). prob: probability of histogram shift. allow_missing_keys: don't raise exception if key is missing. + randomize_per_key: if True, draw independent random parameters for each key instead of sharing + them across all keys (e.g. for independent views in self-supervised learning). Defaults to False. """ backend = RandHistogramShift.backend @@ -1425,9 +1472,11 @@ def __init__( num_control_points: tuple[int, int] | int = 10, prob: float = 0.1, allow_missing_keys: bool = False, + randomize_per_key: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) + self.randomize_per_key = randomize_per_key self.shifter = RandHistogramShift(num_control_points=num_control_points, prob=1.0) def set_random_state( @@ -1445,10 +1494,11 @@ def __call__(self, data: dict[Hashable, NdarrayOrTensor]) -> dict[Hashable, Ndar d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) return d - # all the keys share the same random shift params - self.shifter.randomize(None) + if not self.randomize_per_key: + # all the keys share the same random shift params + self.shifter.randomize(None) for key in self.key_iterator(d): - d[key] = self.shifter(d[key], randomize=False) + d[key] = self.shifter(d[key], randomize=self.randomize_per_key) return d @@ -1476,6 +1526,8 @@ class RandGibbsNoised(RandomizableTransform, MapTransform): uniformly from the interval [a,b]. If a float is given, then the value of alpha will be sampled uniformly from the interval [0, alpha]. allow_missing_keys: do not raise exception if key is missing. + randomize_per_key: if True, draw independent random parameters for each key instead of sharing + them across all keys (e.g. for independent views in self-supervised learning). Defaults to False. """ backend = RandGibbsNoise.backend @@ -1486,9 +1538,11 @@ def __init__( prob: float = 0.1, alpha: float | Sequence[float] = (0.0, 1.0), allow_missing_keys: bool = False, + randomize_per_key: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob=prob) + self.randomize_per_key = randomize_per_key self.rand_gibbs_noise = RandGibbsNoise(alpha=alpha, prob=1.0) def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandGibbsNoised: @@ -1504,10 +1558,11 @@ def __call__(self, data: dict[Hashable, NdarrayOrTensor]) -> dict[Hashable, Ndar d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) return d - # all the keys share the same random noise params - self.rand_gibbs_noise.randomize(None) + if not self.randomize_per_key: + # all the keys share the same random noise params + self.rand_gibbs_noise.randomize(None) for key in self.key_iterator(d): - d[key] = self.rand_gibbs_noise(d[key], randomize=False) + d[key] = self.rand_gibbs_noise(d[key], randomize=self.randomize_per_key) return d diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 51ad0435fc..c987a86a68 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -1657,16 +1657,26 @@ class RandAxisFlipd(RandomizableTransform, MapTransform, InvertibleTransform, La allow_missing_keys: don't raise exception if key is missing. lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False + randomize_per_key: if True, draw an independent random axis for each key instead of sharing one + across all keys (e.g. for independent views in self-supervised learning). Note this breaks the + spatial correspondence between keys, so keep it False for aligned data such as image/label + pairs. Defaults to False. """ backend = RandAxisFlip.backend def __init__( - self, keys: KeysCollection, prob: float = 0.1, allow_missing_keys: bool = False, lazy: bool = False + self, + keys: KeysCollection, + prob: float = 0.1, + allow_missing_keys: bool = False, + lazy: bool = False, + randomize_per_key: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) LazyTransform.__init__(self, lazy=lazy) + self.randomize_per_key = randomize_per_key self.flipper = RandAxisFlip(prob=1.0, lazy=lazy) @LazyTransform.lazy.setter # type: ignore @@ -1699,13 +1709,14 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = No self.randomize(None) - # all the keys share the same random selected axis - self.flipper.randomize(d[first_key]) + if not self.randomize_per_key: + # all the keys share the same random selected axis + self.flipper.randomize(d[first_key]) lazy_ = self.lazy if lazy is None else lazy for key in self.key_iterator(d): if self._do_transform: - d[key] = self.flipper(d[key], randomize=False, lazy=lazy_) + d[key] = self.flipper(d[key], randomize=self.randomize_per_key, lazy=lazy_) else: d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) self.push_transform(d[key], replace=True, lazy=lazy_) @@ -1850,6 +1861,10 @@ class RandRotated(RandomizableTransform, MapTransform, InvertibleTransform, Lazy allow_missing_keys: don't raise exception if key is missing. lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False + randomize_per_key: if True, draw independent random parameters for each key instead of sharing + them across all keys (e.g. for independent views in self-supervised learning). Note this breaks + the spatial correspondence between keys, so keep it False for aligned data such as image/label + pairs. Defaults to False. """ backend = RandRotate.backend @@ -1868,10 +1883,12 @@ def __init__( dtype: Sequence[DtypeLike | torch.dtype] | DtypeLike | torch.dtype = np.float32, allow_missing_keys: bool = False, lazy: bool = False, + randomize_per_key: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) LazyTransform.__init__(self, lazy=lazy) + self.randomize_per_key = randomize_per_key self.rand_rotate = RandRotate( range_x=range_x, range_y=range_y, range_z=range_z, prob=1.0, keep_size=keep_size, lazy=lazy ) @@ -1906,8 +1923,9 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = No d = dict(data) self.randomize(None) - # all the keys share the same random rotate angle - self.rand_rotate.randomize() + if not self.randomize_per_key: + # all the keys share the same random rotate angle + self.rand_rotate.randomize() lazy_ = self.lazy if lazy is None else lazy for key, mode, padding_mode, align_corners, dtype in self.key_iterator( @@ -1920,7 +1938,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = No padding_mode=padding_mode, align_corners=align_corners, dtype=dtype, - randomize=False, + randomize=self.randomize_per_key, lazy=lazy_, ) else: @@ -2076,6 +2094,10 @@ class RandZoomd(RandomizableTransform, MapTransform, InvertibleTransform, LazyTr allow_missing_keys: don't raise exception if key is missing. lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False + randomize_per_key: if True, draw independent random parameters for each key instead of sharing + them across all keys (e.g. for independent views in self-supervised learning). Note this breaks + the spatial correspondence between keys, so keep it False for aligned data such as image/label + pairs. Defaults to False. kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html """ @@ -2095,11 +2117,13 @@ def __init__( keep_size: bool = True, allow_missing_keys: bool = False, lazy: bool = False, + randomize_per_key: bool = False, **kwargs, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) LazyTransform.__init__(self, lazy=lazy) + self.randomize_per_key = randomize_per_key self.rand_zoom = RandZoom( prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, keep_size=keep_size, lazy=lazy, **kwargs ) @@ -2139,8 +2163,9 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = No self.randomize(None) - # all the keys share the same random zoom factor - self.rand_zoom.randomize(d[first_key]) + if not self.randomize_per_key: + # all the keys share the same random zoom factor + self.rand_zoom.randomize(d[first_key]) lazy_ = self.lazy if lazy is None else lazy for key, mode, padding_mode, align_corners, dtype in self.key_iterator( @@ -2153,7 +2178,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = No padding_mode=padding_mode, align_corners=align_corners, dtype=dtype, - randomize=False, + randomize=self.randomize_per_key, lazy=lazy_, ) else: @@ -2250,6 +2275,7 @@ def __init__( padding_mode: str = GridSamplePadMode.BORDER, device: torch.device | None = None, allow_missing_keys: bool = False, + randomize_per_key: bool = False, ) -> None: """ Args: @@ -2275,10 +2301,15 @@ def __init__( It also can be a sequence, each element corresponds to a key in ``keys``. device: device on which the tensor will be allocated. allow_missing_keys: don't raise exception if key is missing. + randomize_per_key: if True, draw independent random parameters for each key instead of sharing + them across all keys (e.g. for independent views in self-supervised learning). Note this + breaks the spatial correspondence between keys, so keep it False for aligned data such as + image/label pairs. Defaults to False. """ MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) + self.randomize_per_key = randomize_per_key self.rand_grid_distortion = RandGridDistortion( num_cells=num_cells, prob=1.0, distort_limit=distort_limit, device=device ) @@ -2314,10 +2345,13 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc return out if isinstance(d[first_key], MetaTensor) and d[first_key].pending_operations: # type: ignore warnings.warn(f"data['{first_key}'] has pending operations, transform may return incorrect results.") - self.rand_grid_distortion.randomize(d[first_key].shape[1:]) + if not self.randomize_per_key: + self.rand_grid_distortion.randomize(d[first_key].shape[1:]) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): - d[key] = self.rand_grid_distortion(d[key], mode=mode, padding_mode=padding_mode, randomize=False) + d[key] = self.rand_grid_distortion( + d[key], mode=mode, padding_mode=padding_mode, randomize=self.randomize_per_key + ) return d From 377d3ead00c8314ab235629dcb6e7eecaa2bc884 Mon Sep 17 00:00:00 2001 From: Soumya Snigdha Kundu Date: Mon, 29 Jun 2026 17:56:24 +0100 Subject: [PATCH 2/2] tests(transforms): cover randomize_per_key across intensity and spatial dict transforms Signed-off-by: Soumya Snigdha Kundu --- tests/transforms/test_rand_per_key.py | 82 +++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) create mode 100644 tests/transforms/test_rand_per_key.py diff --git a/tests/transforms/test_rand_per_key.py b/tests/transforms/test_rand_per_key.py new file mode 100644 index 0000000000..9fe7277bfe --- /dev/null +++ b/tests/transforms/test_rand_per_key.py @@ -0,0 +1,82 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.transforms.intensity.dictionary import ( + RandAdjustContrastd, + RandBiasFieldd, + RandGaussianNoised, + RandGaussianSharpend, + RandGaussianSmoothd, + RandGibbsNoised, + RandHistogramShiftd, + RandScaleIntensityd, + RandScaleIntensityFixedMeand, + RandShiftIntensityd, + RandStdShiftIntensityd, +) +from monai.transforms.spatial.dictionary import RandAxisFlipd, RandGridDistortiond, RandRotated, RandZoomd +from tests.test_utils import assert_allclose + +KEYS = ["img1", "img2"] + +TESTS = [ + (RandGaussianNoised, {}), + (RandShiftIntensityd, {"offsets": 0.5}), + (RandStdShiftIntensityd, {"factors": 0.5}), + (RandScaleIntensityd, {"factors": 0.5}), + (RandScaleIntensityFixedMeand, {"factors": 0.5}), + (RandBiasFieldd, {}), + (RandAdjustContrastd, {}), + (RandGaussianSmoothd, {}), + (RandGaussianSharpend, {}), + (RandHistogramShiftd, {"num_control_points": (5, 20)}), + (RandGibbsNoised, {}), + (RandAxisFlipd, {}), + (RandRotated, {"range_x": 1.0, "range_y": 1.0, "range_z": 1.0}), + (RandZoomd, {"min_zoom": 0.7, "max_zoom": 1.3}), + (RandGridDistortiond, {"num_cells": 3, "distort_limit": 0.2}), +] + + +class TestRandPerKey(unittest.TestCase): + @parameterized.expand([(cls.__name__, cls, kwargs) for cls, kwargs in TESTS]) + def test_shared_default(self, _, cls, kwargs): + t = cls(keys=KEYS, prob=1.0, **kwargs) + t.set_random_state(0) + img = torch.rand(1, 8, 8, 8) + 1.0 + out = t({k: img.clone() for k in KEYS}) + assert_allclose(out["img1"], out["img2"], type_test=False) + + @parameterized.expand([(cls.__name__, cls, kwargs) for cls, kwargs in TESTS]) + def test_independent_per_key(self, _, cls, kwargs): + # independent draws may coincide for a single seed (e.g. RandAxisFlipd's discrete axis), + # so only require that some deterministic seed yields divergent per-key outputs + img = torch.rand(1, 8, 8, 8) + 1.0 + differs = False + for seed in range(10): + t = cls(keys=KEYS, prob=1.0, randomize_per_key=True, **kwargs) + t.set_random_state(seed) + out = t({k: img.clone() for k in KEYS}) + if not torch.allclose(out["img1"], out["img2"]): + differs = True + break + self.assertTrue(differs) + + +if __name__ == "__main__": + unittest.main()