diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 0c25d4ac99..101852a88c 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -188,6 +188,8 @@ class RandGaussianNoised(RandomizableTransform, MapTransform): dtype: output data type, if None, same as input image. defaults to float32. allow_missing_keys: don't raise exception if key is missing. sample_std: If True, sample the spread of the Gaussian distribution uniformly from 0 to std. + randomize_per_key: if True, draw independent random parameters for each key instead of sharing + them across all keys (e.g. for independent views in self-supervised learning). Defaults to False. """ backend = RandGaussianNoise.backend @@ -201,9 +203,11 @@ def __init__( dtype: DtypeLike = np.float32, allow_missing_keys: bool = False, sample_std: bool = True, + randomize_per_key: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) + self.randomize_per_key = randomize_per_key self.rand_gaussian_noise = RandGaussianNoise(mean=mean, std=std, prob=1.0, dtype=dtype, sample_std=sample_std) def set_random_state( @@ -221,17 +225,18 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) return d - # all the keys share the same random noise first_key: Hashable = self.first_key(d) if first_key == (): for key in self.key_iterator(d): d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) return d - self.rand_gaussian_noise.randomize(d[first_key]) + if not self.randomize_per_key: + # all the keys share the same random noise + self.rand_gaussian_noise.randomize(d[first_key]) for key in self.key_iterator(d): - d[key] = self.rand_gaussian_noise(img=d[key], randomize=False) + d[key] = self.rand_gaussian_noise(img=d[key], randomize=self.randomize_per_key) return d @@ -381,6 +386,7 @@ def __init__( prob: float = 0.1, channel_wise: bool = False, allow_missing_keys: bool = False, + randomize_per_key: bool = False, ) -> None: """ Args: @@ -409,10 +415,13 @@ def __init__( channel_wise: if True, shift intensity on each channel separately. For each channel, a random offset will be chosen. Please ensure that the first dimension represents the channel of the image if True. allow_missing_keys: don't raise exception if key is missing. + randomize_per_key: if True, draw independent random parameters for each key instead of sharing + them across all keys (e.g. for independent views in self-supervised learning). Defaults to False. """ MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) + self.randomize_per_key = randomize_per_key self.factor_key = ensure_tuple_rep(factor_key, len(self.keys)) self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys) if len(self.keys) != len(self.meta_keys): @@ -442,14 +451,15 @@ def __call__(self, data) -> dict[Hashable, NdarrayOrTensor]: d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) return d - # all the keys share the same random shift factor - self.shifter.randomize(d[first_key]) + if not self.randomize_per_key: + # all the keys share the same random shift factor + self.shifter.randomize(d[first_key]) for key, factor_key, meta_key, meta_key_postfix in self.key_iterator( d, self.factor_key, self.meta_keys, self.meta_key_postfix ): meta_key = meta_key or f"{key}_{meta_key_postfix}" factor: float | None = d[meta_key].get(factor_key) if meta_key in d else None - d[key] = self.shifter(d[key], factor=factor, randomize=False) + d[key] = self.shifter(d[key], factor=factor, randomize=self.randomize_per_key) return d @@ -506,6 +516,7 @@ def __init__( channel_wise: bool = False, dtype: DtypeLike = np.float32, allow_missing_keys: bool = False, + randomize_per_key: bool = False, ) -> None: """ Args: @@ -518,9 +529,12 @@ def __init__( channel_wise: if True, calculate on each channel separately. dtype: output data type, if None, same as input image. defaults to float32. allow_missing_keys: don't raise exception if key is missing. + randomize_per_key: if True, draw independent random parameters for each key instead of sharing + them across all keys (e.g. for independent views in self-supervised learning). Defaults to False. """ MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) + self.randomize_per_key = randomize_per_key self.shifter = RandStdShiftIntensity( factors=factors, nonzero=nonzero, channel_wise=channel_wise, dtype=dtype, prob=1.0 ) @@ -540,10 +554,11 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) return d - # all the keys share the same random shift factor - self.shifter.randomize(None) + if not self.randomize_per_key: + # all the keys share the same random shift factor + self.shifter.randomize(None) for key in self.key_iterator(d): - d[key] = self.shifter(d[key], randomize=False) + d[key] = self.shifter(d[key], randomize=self.randomize_per_key) return d @@ -605,6 +620,7 @@ def __init__( channel_wise: bool = False, dtype: DtypeLike = np.float32, allow_missing_keys: bool = False, + randomize_per_key: bool = False, ) -> None: """ Args: @@ -618,10 +634,13 @@ def __init__( that the first dimension represents the channel of the image if True. dtype: output data type, if None, same as input image. defaults to float32. allow_missing_keys: don't raise exception if key is missing. + randomize_per_key: if True, draw independent random parameters for each key instead of sharing + them across all keys (e.g. for independent views in self-supervised learning). Defaults to False. """ MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) + self.randomize_per_key = randomize_per_key self.scaler = RandScaleIntensity(factors=factors, dtype=dtype, prob=1.0, channel_wise=channel_wise) def set_random_state( @@ -646,10 +665,11 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) return d - # all the keys share the same random scale factor - self.scaler.randomize(d[first_key]) + if not self.randomize_per_key: + # all the keys share the same random scale factor + self.scaler.randomize(d[first_key]) for key in self.key_iterator(d): - d[key] = self.scaler(d[key], randomize=False) + d[key] = self.scaler(d[key], randomize=self.randomize_per_key) return d @@ -672,6 +692,7 @@ def __init__( dtype: DtypeLike = np.float32, allow_missing_keys: bool = False, channel_wise: bool = False, + randomize_per_key: bool = False, ) -> None: """ Args: @@ -687,10 +708,13 @@ def __init__( channel_wise: if True, scale on each channel separately. `preserve_range` and `fixed_mean` are also applied on each channel separately if `channel_wise` is True. Please ensure that the first dimension represents the channel of the image if True. + randomize_per_key: if True, draw independent random parameters for each key instead of sharing + them across all keys (e.g. for independent views in self-supervised learning). Defaults to False. """ MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) + self.randomize_per_key = randomize_per_key self.fixed_mean = fixed_mean self.preserve_range = preserve_range self.scaler = RandScaleIntensityFixedMean( @@ -724,10 +748,11 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) return d - # all the keys share the same random scale factor - self.scaler.randomize(d[first_key]) + if not self.randomize_per_key: + # all the keys share the same random scale factor + self.scaler.randomize(d[first_key]) for key in self.key_iterator(d): - d[key] = self.scaler(d[key], randomize=False) + d[key] = self.scaler(d[key], randomize=self.randomize_per_key) return d @@ -746,6 +771,7 @@ def __init__( dtype: DtypeLike = np.float32, prob: float = 0.1, allow_missing_keys: bool = False, + randomize_per_key: bool = False, ) -> None: """ Args: @@ -757,11 +783,14 @@ def __init__( dtype: output data type, if None, same as input image. defaults to float32. prob: probability to do random bias field. allow_missing_keys: don't raise exception if key is missing. + randomize_per_key: if True, draw independent random parameters for each key instead of sharing + them across all keys (e.g. for independent views in self-supervised learning). Defaults to False. """ MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) + self.randomize_per_key = randomize_per_key self.rand_bias_field = RandBiasField(degree=degree, coeff_range=coeff_range, dtype=dtype, prob=1.0) def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandBiasFieldd: @@ -777,17 +806,18 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) return d - # all the keys share the same random bias factor first_key: Hashable = self.first_key(d) if first_key == (): for key in self.key_iterator(d): d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) return d - self.rand_bias_field.randomize(img_size=d[first_key].shape[1:]) + if not self.randomize_per_key: + # all the keys share the same random bias factor + self.rand_bias_field.randomize(img_size=d[first_key].shape[1:]) for key in self.key_iterator(d): - d[key] = self.rand_bias_field(d[key], randomize=False) + d[key] = self.rand_bias_field(d[key], randomize=self.randomize_per_key) return d @@ -1003,6 +1033,8 @@ class RandAdjustContrastd(RandomizableTransform, MapTransform): `_ function. allow_missing_keys: don't raise exception if key is missing. + randomize_per_key: if True, draw independent random parameters for each key instead of sharing + them across all keys (e.g. for independent views in self-supervised learning). Defaults to False. """ backend = RandAdjustContrast.backend @@ -1015,9 +1047,11 @@ def __init__( invert_image: bool = False, retain_stats: bool = False, allow_missing_keys: bool = False, + randomize_per_key: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) + self.randomize_per_key = randomize_per_key self.adjuster = RandAdjustContrast(gamma=gamma, prob=1.0, invert_image=invert_image, retain_stats=retain_stats) self.invert_image = invert_image @@ -1036,10 +1070,11 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) return d - # all the keys share the same random gamma value - self.adjuster.randomize(None) + if not self.randomize_per_key: + # all the keys share the same random gamma value + self.adjuster.randomize(None) for key in self.key_iterator(d): - d[key] = self.adjuster(d[key], randomize=False) + d[key] = self.adjuster(d[key], randomize=self.randomize_per_key) return d @@ -1242,6 +1277,8 @@ class RandGaussianSmoothd(RandomizableTransform, MapTransform): see also :py:meth:`monai.networks.layers.GaussianFilter`. prob: probability of Gaussian smooth. allow_missing_keys: don't raise exception if key is missing. + randomize_per_key: if True, draw independent random parameters for each key instead of sharing + them across all keys (e.g. for independent views in self-supervised learning). Defaults to False. """ @@ -1256,9 +1293,11 @@ def __init__( approx: str = "erf", prob: float = 0.1, allow_missing_keys: bool = False, + randomize_per_key: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) + self.randomize_per_key = randomize_per_key self.rand_smooth = RandGaussianSmooth( sigma_x=sigma_x, sigma_y=sigma_y, sigma_z=sigma_z, approx=approx, prob=1.0 ) @@ -1278,10 +1317,11 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) return d - # all the keys share the same random sigma - self.rand_smooth.randomize(None) + if not self.randomize_per_key: + # all the keys share the same random sigma + self.rand_smooth.randomize(None) for key in self.key_iterator(d): - d[key] = self.rand_smooth(d[key], randomize=False) + d[key] = self.rand_smooth(d[key], randomize=self.randomize_per_key) return d @@ -1347,6 +1387,8 @@ class RandGaussianSharpend(RandomizableTransform, MapTransform): see also :py:meth:`monai.networks.layers.GaussianFilter`. prob: probability of Gaussian sharpen. allow_missing_keys: don't raise exception if key is missing. + randomize_per_key: if True, draw independent random parameters for each key instead of sharing + them across all keys (e.g. for independent views in self-supervised learning). Defaults to False. """ @@ -1365,9 +1407,11 @@ def __init__( approx: str = "erf", prob: float = 0.1, allow_missing_keys: bool = False, + randomize_per_key: bool = False, ): MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) + self.randomize_per_key = randomize_per_key self.rand_sharpen = RandGaussianSharpen( sigma1_x=sigma1_x, sigma1_y=sigma1_y, @@ -1395,10 +1439,11 @@ def __call__(self, data: dict[Hashable, NdarrayOrTensor]) -> dict[Hashable, Ndar d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) return d - # all the keys share the same random sigma1, sigma2, etc. - self.rand_sharpen.randomize(None) + if not self.randomize_per_key: + # all the keys share the same random sigma1, sigma2, etc. + self.rand_sharpen.randomize(None) for key in self.key_iterator(d): - d[key] = self.rand_sharpen(d[key], randomize=False) + d[key] = self.rand_sharpen(d[key], randomize=self.randomize_per_key) return d @@ -1415,6 +1460,8 @@ class RandHistogramShiftd(RandomizableTransform, MapTransform): control points selecting from range (min_value, max_value). prob: probability of histogram shift. allow_missing_keys: don't raise exception if key is missing. + randomize_per_key: if True, draw independent random parameters for each key instead of sharing + them across all keys (e.g. for independent views in self-supervised learning). Defaults to False. """ backend = RandHistogramShift.backend @@ -1425,9 +1472,11 @@ def __init__( num_control_points: tuple[int, int] | int = 10, prob: float = 0.1, allow_missing_keys: bool = False, + randomize_per_key: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) + self.randomize_per_key = randomize_per_key self.shifter = RandHistogramShift(num_control_points=num_control_points, prob=1.0) def set_random_state( @@ -1445,10 +1494,11 @@ def __call__(self, data: dict[Hashable, NdarrayOrTensor]) -> dict[Hashable, Ndar d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) return d - # all the keys share the same random shift params - self.shifter.randomize(None) + if not self.randomize_per_key: + # all the keys share the same random shift params + self.shifter.randomize(None) for key in self.key_iterator(d): - d[key] = self.shifter(d[key], randomize=False) + d[key] = self.shifter(d[key], randomize=self.randomize_per_key) return d @@ -1476,6 +1526,8 @@ class RandGibbsNoised(RandomizableTransform, MapTransform): uniformly from the interval [a,b]. If a float is given, then the value of alpha will be sampled uniformly from the interval [0, alpha]. allow_missing_keys: do not raise exception if key is missing. + randomize_per_key: if True, draw independent random parameters for each key instead of sharing + them across all keys (e.g. for independent views in self-supervised learning). Defaults to False. """ backend = RandGibbsNoise.backend @@ -1486,9 +1538,11 @@ def __init__( prob: float = 0.1, alpha: float | Sequence[float] = (0.0, 1.0), allow_missing_keys: bool = False, + randomize_per_key: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob=prob) + self.randomize_per_key = randomize_per_key self.rand_gibbs_noise = RandGibbsNoise(alpha=alpha, prob=1.0) def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandGibbsNoised: @@ -1504,10 +1558,11 @@ def __call__(self, data: dict[Hashable, NdarrayOrTensor]) -> dict[Hashable, Ndar d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) return d - # all the keys share the same random noise params - self.rand_gibbs_noise.randomize(None) + if not self.randomize_per_key: + # all the keys share the same random noise params + self.rand_gibbs_noise.randomize(None) for key in self.key_iterator(d): - d[key] = self.rand_gibbs_noise(d[key], randomize=False) + d[key] = self.rand_gibbs_noise(d[key], randomize=self.randomize_per_key) return d diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 51ad0435fc..c987a86a68 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -1657,16 +1657,26 @@ class RandAxisFlipd(RandomizableTransform, MapTransform, InvertibleTransform, La allow_missing_keys: don't raise exception if key is missing. lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False + randomize_per_key: if True, draw an independent random axis for each key instead of sharing one + across all keys (e.g. for independent views in self-supervised learning). Note this breaks the + spatial correspondence between keys, so keep it False for aligned data such as image/label + pairs. Defaults to False. """ backend = RandAxisFlip.backend def __init__( - self, keys: KeysCollection, prob: float = 0.1, allow_missing_keys: bool = False, lazy: bool = False + self, + keys: KeysCollection, + prob: float = 0.1, + allow_missing_keys: bool = False, + lazy: bool = False, + randomize_per_key: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) LazyTransform.__init__(self, lazy=lazy) + self.randomize_per_key = randomize_per_key self.flipper = RandAxisFlip(prob=1.0, lazy=lazy) @LazyTransform.lazy.setter # type: ignore @@ -1699,13 +1709,14 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = No self.randomize(None) - # all the keys share the same random selected axis - self.flipper.randomize(d[first_key]) + if not self.randomize_per_key: + # all the keys share the same random selected axis + self.flipper.randomize(d[first_key]) lazy_ = self.lazy if lazy is None else lazy for key in self.key_iterator(d): if self._do_transform: - d[key] = self.flipper(d[key], randomize=False, lazy=lazy_) + d[key] = self.flipper(d[key], randomize=self.randomize_per_key, lazy=lazy_) else: d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) self.push_transform(d[key], replace=True, lazy=lazy_) @@ -1850,6 +1861,10 @@ class RandRotated(RandomizableTransform, MapTransform, InvertibleTransform, Lazy allow_missing_keys: don't raise exception if key is missing. lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False + randomize_per_key: if True, draw independent random parameters for each key instead of sharing + them across all keys (e.g. for independent views in self-supervised learning). Note this breaks + the spatial correspondence between keys, so keep it False for aligned data such as image/label + pairs. Defaults to False. """ backend = RandRotate.backend @@ -1868,10 +1883,12 @@ def __init__( dtype: Sequence[DtypeLike | torch.dtype] | DtypeLike | torch.dtype = np.float32, allow_missing_keys: bool = False, lazy: bool = False, + randomize_per_key: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) LazyTransform.__init__(self, lazy=lazy) + self.randomize_per_key = randomize_per_key self.rand_rotate = RandRotate( range_x=range_x, range_y=range_y, range_z=range_z, prob=1.0, keep_size=keep_size, lazy=lazy ) @@ -1906,8 +1923,9 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = No d = dict(data) self.randomize(None) - # all the keys share the same random rotate angle - self.rand_rotate.randomize() + if not self.randomize_per_key: + # all the keys share the same random rotate angle + self.rand_rotate.randomize() lazy_ = self.lazy if lazy is None else lazy for key, mode, padding_mode, align_corners, dtype in self.key_iterator( @@ -1920,7 +1938,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = No padding_mode=padding_mode, align_corners=align_corners, dtype=dtype, - randomize=False, + randomize=self.randomize_per_key, lazy=lazy_, ) else: @@ -2076,6 +2094,10 @@ class RandZoomd(RandomizableTransform, MapTransform, InvertibleTransform, LazyTr allow_missing_keys: don't raise exception if key is missing. lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False + randomize_per_key: if True, draw independent random parameters for each key instead of sharing + them across all keys (e.g. for independent views in self-supervised learning). Note this breaks + the spatial correspondence between keys, so keep it False for aligned data such as image/label + pairs. Defaults to False. kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html """ @@ -2095,11 +2117,13 @@ def __init__( keep_size: bool = True, allow_missing_keys: bool = False, lazy: bool = False, + randomize_per_key: bool = False, **kwargs, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) LazyTransform.__init__(self, lazy=lazy) + self.randomize_per_key = randomize_per_key self.rand_zoom = RandZoom( prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, keep_size=keep_size, lazy=lazy, **kwargs ) @@ -2139,8 +2163,9 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = No self.randomize(None) - # all the keys share the same random zoom factor - self.rand_zoom.randomize(d[first_key]) + if not self.randomize_per_key: + # all the keys share the same random zoom factor + self.rand_zoom.randomize(d[first_key]) lazy_ = self.lazy if lazy is None else lazy for key, mode, padding_mode, align_corners, dtype in self.key_iterator( @@ -2153,7 +2178,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = No padding_mode=padding_mode, align_corners=align_corners, dtype=dtype, - randomize=False, + randomize=self.randomize_per_key, lazy=lazy_, ) else: @@ -2250,6 +2275,7 @@ def __init__( padding_mode: str = GridSamplePadMode.BORDER, device: torch.device | None = None, allow_missing_keys: bool = False, + randomize_per_key: bool = False, ) -> None: """ Args: @@ -2275,10 +2301,15 @@ def __init__( It also can be a sequence, each element corresponds to a key in ``keys``. device: device on which the tensor will be allocated. allow_missing_keys: don't raise exception if key is missing. + randomize_per_key: if True, draw independent random parameters for each key instead of sharing + them across all keys (e.g. for independent views in self-supervised learning). Note this + breaks the spatial correspondence between keys, so keep it False for aligned data such as + image/label pairs. Defaults to False. """ MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) + self.randomize_per_key = randomize_per_key self.rand_grid_distortion = RandGridDistortion( num_cells=num_cells, prob=1.0, distort_limit=distort_limit, device=device ) @@ -2314,10 +2345,13 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc return out if isinstance(d[first_key], MetaTensor) and d[first_key].pending_operations: # type: ignore warnings.warn(f"data['{first_key}'] has pending operations, transform may return incorrect results.") - self.rand_grid_distortion.randomize(d[first_key].shape[1:]) + if not self.randomize_per_key: + self.rand_grid_distortion.randomize(d[first_key].shape[1:]) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): - d[key] = self.rand_grid_distortion(d[key], mode=mode, padding_mode=padding_mode, randomize=False) + d[key] = self.rand_grid_distortion( + d[key], mode=mode, padding_mode=padding_mode, randomize=self.randomize_per_key + ) return d diff --git a/tests/transforms/test_rand_per_key.py b/tests/transforms/test_rand_per_key.py new file mode 100644 index 0000000000..9fe7277bfe --- /dev/null +++ b/tests/transforms/test_rand_per_key.py @@ -0,0 +1,82 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.transforms.intensity.dictionary import ( + RandAdjustContrastd, + RandBiasFieldd, + RandGaussianNoised, + RandGaussianSharpend, + RandGaussianSmoothd, + RandGibbsNoised, + RandHistogramShiftd, + RandScaleIntensityd, + RandScaleIntensityFixedMeand, + RandShiftIntensityd, + RandStdShiftIntensityd, +) +from monai.transforms.spatial.dictionary import RandAxisFlipd, RandGridDistortiond, RandRotated, RandZoomd +from tests.test_utils import assert_allclose + +KEYS = ["img1", "img2"] + +TESTS = [ + (RandGaussianNoised, {}), + (RandShiftIntensityd, {"offsets": 0.5}), + (RandStdShiftIntensityd, {"factors": 0.5}), + (RandScaleIntensityd, {"factors": 0.5}), + (RandScaleIntensityFixedMeand, {"factors": 0.5}), + (RandBiasFieldd, {}), + (RandAdjustContrastd, {}), + (RandGaussianSmoothd, {}), + (RandGaussianSharpend, {}), + (RandHistogramShiftd, {"num_control_points": (5, 20)}), + (RandGibbsNoised, {}), + (RandAxisFlipd, {}), + (RandRotated, {"range_x": 1.0, "range_y": 1.0, "range_z": 1.0}), + (RandZoomd, {"min_zoom": 0.7, "max_zoom": 1.3}), + (RandGridDistortiond, {"num_cells": 3, "distort_limit": 0.2}), +] + + +class TestRandPerKey(unittest.TestCase): + @parameterized.expand([(cls.__name__, cls, kwargs) for cls, kwargs in TESTS]) + def test_shared_default(self, _, cls, kwargs): + t = cls(keys=KEYS, prob=1.0, **kwargs) + t.set_random_state(0) + img = torch.rand(1, 8, 8, 8) + 1.0 + out = t({k: img.clone() for k in KEYS}) + assert_allclose(out["img1"], out["img2"], type_test=False) + + @parameterized.expand([(cls.__name__, cls, kwargs) for cls, kwargs in TESTS]) + def test_independent_per_key(self, _, cls, kwargs): + # independent draws may coincide for a single seed (e.g. RandAxisFlipd's discrete axis), + # so only require that some deterministic seed yields divergent per-key outputs + img = torch.rand(1, 8, 8, 8) + 1.0 + differs = False + for seed in range(10): + t = cls(keys=KEYS, prob=1.0, randomize_per_key=True, **kwargs) + t.set_random_state(seed) + out = t({k: img.clone() for k in KEYS}) + if not torch.allclose(out["img1"], out["img2"]): + differs = True + break + self.assertTrue(differs) + + +if __name__ == "__main__": + unittest.main()