SwinUNETR + HyenaND: subquadratic alternative to windowed self-attention #8958
farhadrgh wants to merge 8 commits into
Signed-off-by: Farhad Ramezanghorbani <farhadr@nvidia.com>
for more information, see https://pre-commit.ci
📝 Walkthrough
Adds HyenaND-based sequence modeling blocks to MONAI:
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Actionable comments posted: 7
🧹 Nitpick comments (6)
docs/source/networks.rst (1)
132-148: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win
Consider documenting `is_nvsubquadratic_available()`.
The function is public API (`__all__`) and lets users detect the optional dependency at runtime. Add an `autofunction` entry near the Hyena block classes:
    `Hyena Utilities`
    ~~~~~~~~~~~~~~~~~
    .. autofunction:: monai.networks.blocks.hyena.is_nvsubquadratic_available
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@docs/source/networks.rst` around lines 132 - 148, Add documentation for the public helper is_nvsubquadratic_available() in the Hyena networks docs, since it is part of the exposed API and users need a runtime way to detect the optional dependency. Update the Hyena section in networks.rst near HyenaMixer and HyenaTransformerBlock by adding a Hyena Utilities subsection with an autofunction entry for monai.networks.blocks.hyena.is_nvsubquadratic_available, keeping it grouped with the other Hyena-related APIs.
CHANGELOG.md (1)
7-11: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win
Consider adding `is_nvsubquadratic_available()` and `load_from` skip behavior to the changelog.
The entry is accurate but omits two user-visible behaviors:
- `monai.networks.blocks.hyena.is_nvsubquadratic_available()` for runtime optional-dependency detection.
- `SwinUNETR.load_from` now skips Hyena stages with a warning rather than failing.
Both are worth noting for users integrating Hyena conditionally.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@CHANGELOG.md` around lines 7 - 11, The Hyena changelog entry is missing two user-visible behaviors that should be called out for users integrating the optional dependency conditionally. Update the `CHANGELOG.md` “Added” section to mention `monai.networks.blocks.hyena.is_nvsubquadratic_available()` as the runtime availability check, and note that `SwinUNETR.load_from` now skips Hyena stages with a warning instead of failing. Keep the wording concise and tie both items to the existing Hyena additions so readers can find the relevant APIs easily.
monai/networks/blocks/hyena.py (1)
95-135: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win
Add the required Google-style docstrings.
Several new definitions lack docstrings or Args/Returns/Raises sections. As per path instructions, “Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings.”
Also applies to: 149-177, 189-217, 389-413, 517-548
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@monai/networks/blocks/hyena.py` around lines 95 - 135, Add Google-style docstrings to the newly introduced definitions in hyena.py, especially the forward methods and related helpers referenced by the diff, so each function/class clearly documents its purpose, all parameters/variables, return value, and any raised exceptions. Update the affected symbols in the Hyena block implementations and any other new definitions called out in the review to include the required Args and Returns/Raises sections, matching the repository’s docstring conventions.
Source: Path instructions
tests/networks/blocks/test_hyena_block.py (1)
327-337: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win
Assert the FFT short-conv wiring.
These tests still pass if `HyenaMixer` ignores `use_fft_short_conv` or `short_conv_fft_chunk_size`, because plain `nn.Conv3d` preserves the same shape. Add a type/config assertion on the constructed short-conv module so the branch in `monai/networks/blocks/hyena.py:274-368` is actually covered. As per path instructions, “Ensure new or modified definitions will be covered by existing or new unit tests.”
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/networks/blocks/test_hyena_block.py` around lines 327 - 337, The HyenaMixer 3D tests only verify output shape, so they still pass even if the FFT short-conv path is never used. Update the tests around HyenaMixer construction to assert the short-conv module type/config when use_fft_short_conv and short_conv_fft_chunk_size are set, so the branch in HyenaMixer’s short-conv wiring is explicitly exercised and covered by unit tests.
Source: Path instructions
tests/networks/nets/test_hyena_nd_unetr.py (1)
40-95: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win
Add the missing no-dependency constructor test.
Every constructor-contract case is skipped when `nvsubquadratic` is absent, but `HyenaNDUNETR` documents an `ImportError` path in `monai/networks/nets/hyena_nd_unetr.py:59-165`. Add a `@skipUnless(not HAS_NVSUBQ, ...)` case so that contract stays covered. As per path instructions, “Ensure new or modified definitions will be covered by existing or new unit tests.”
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/networks/nets/test_hyena_nd_unetr.py` around lines 40 - 95, Add a missing constructor-contract test for the no-dependency path in TestHyenaNDUNETRConstructorContract so coverage still exists when HAS_NVSUBQ is false. Create a new test alongside the existing HyenaNDUNETR constructor tests that is skipped unless nvsubquadratic is absent and asserts the documented ImportError behavior from HyenaNDUNETR/__init__. Keep the focus on the constructor symbols HyenaNDUNETR and TestHyenaNDUNETRConstructorContract, and ensure this path is exercised by a unit test rather than only the dependency-present cases.
Source: Path instructions
monai/networks/nets/swin_unetr.py (1)
331-340: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win
Add Google-style sections to `load_from`.
This public method now has custom checkpoint-loading behavior, but `weights`, return value, and expected exceptions are not documented.
Proposed docstring update
     def load_from(self, weights):
         """Load pretrained Swin weights into the matching submodules.

         When a stage uses :class:`HyenaTransformerBlock` instead of
         :class:`SwinTransformerBlock`, the per-block ``load_from`` call is
         skipped for that stage and a warning is issued -- HyenaND has a
         different parameter layout and there are no compatible attention
         weights to copy. PatchMerging downsample weights are still loaded
         for all stages (the downsample layer is the same in both code paths).
+
+        Args:
+            weights: Checkpoint mapping containing a ``"state_dict"`` with Swin
+                pretrained parameter tensors.
+
+        Returns:
+            None.
+
+        Raises:
+            KeyError: If an expected checkpoint key is absent.
+            RuntimeError: If a checkpoint tensor shape is incompatible.
         """
As per path instructions, “Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings.”
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@monai/networks/nets/swin_unetr.py` around lines 331 - 340, The public method load_from in SwinUNETR needs a full Google-style docstring update because it now has custom checkpoint-loading behavior without documented parameters, return value, or exceptions. Expand the existing docstring to add Args for weights (and any other inputs used by load_from), a Returns section if it returns a value, and a Raises section for any expected exceptions, while keeping the current summary about Swin and HyenaTransformerBlock loading behavior.
Source: Path instructions
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@docs/source/installation.md`:
- Around line 263-266: Update the `hyena` extra documentation to point the
`nvsubquadratic` hyperlink at the correct NVIDIA-BioNeMo repository instead of
the dead NVIDIA URL. Locate the `hyena` installation description in the docs and
replace the existing link target so readers using `HyenaNDUNETR`, `HyenaMixer`,
and `HyenaTransformerBlock` are directed to the valid project page.
In `@monai/networks/blocks/hyena.py`:
- Around line 51-56: The availability check in the Hyena module currently
depends only on the LazyConfig import, so it can report nvsubquadratic as
available even when instantiate, Hyena, CKConvND, SIRENKernelND, or
GaussianModulationND fail to import. Update the import setup in hyena.py so each
optional_import contributes a flag, and change is_nvsubquadratic_available() to
require all of those symbols to be present before returning true.
- Around line 95-135: `Hyena.forward` currently hardcodes the FFT crop around
`kernel_shape // 2`, so `DepthwiseFFTConv{2,3}d` ignores its `self.padding`
setting and can produce incorrect outputs for non-default padding or even
kernels. Update the crop logic in `forward()` (including the chunked path) to
use the module’s padding value when building `slices`, and keep the FFT result
aligned with Conv{2,3}d semantics for both output shape and values.
In `@monai/networks/nets/swin_unetr.py`:
- Around line 363-367: The warning emitted in `load_from` for
`HyenaTransformerBlock` should set an explicit stacklevel so the caller sees the
warning at their call site instead of inside this helper. Update the existing
warnings.warn call in `swin_unetr.py` to include a stacklevel argument, keeping
the same message and using the `load_from` path where `layer_name` and
`HyenaTransformerBlock` are handled.
- Around line 956-1001: The Hyena path still pays the cost of attention-mask
construction even though HyenaTransformerBlock.forward ignores mask_matrix.
Update the stage forward logic that uses self.use_hyena and compute_mask(...) so
mask generation is skipped entirely for Hyena stages, while preserving the
existing SwinTransformerBlock path for non-Hyena stages. Keep the change
localized to the stage module that initializes self.blocks and dispatches to the
block forward calls.
In `@tests/networks/nets/test_hyena_nd_unetr.py`:
- Around line 76-84: The duplicate hyena_stages kwarg assertion in
test_duplicate_hyena_stages_kwarg_rejected is unreachable because Python raises
TypeError at the call site before HyenaNDUNETR.__init__ can inspect kwargs.
Remove this test or rewrite it to cover a reachable duplicate-argument path, and
if keeping a kwargs validation check in HyenaNDUNETR.__init__, add a separate
test that passes hyena_stages only through kwargs so the branch can actually
execute.
In `@tests/networks/nets/test_swin_unetr.py`:
- Around line 170-187: The golden check in test_default_path_unchanged is too
strict because it hashes raw CUDA bytes, so replace the SHA256 comparison with a
tolerance-based validation such as torch.testing.assert_close against a stored
reference tensor or another numeric invariant. Keep the test focused on the
SwinUNETR default forward path and preserve the existing setup in
test_default_path_unchanged, but avoid byte-level equality that can fail from
harmless GPU/PyTorch drift.
---
Nitpick comments:
In `@CHANGELOG.md`:
- Around line 7-11: The Hyena changelog entry is missing two user-visible
behaviors that should be called out for users integrating the optional
dependency conditionally. Update the `CHANGELOG.md` “Added” section to mention
`monai.networks.blocks.hyena.is_nvsubquadratic_available()` as the runtime
availability check, and note that `SwinUNETR.load_from` now skips Hyena stages
with a warning instead of failing. Keep the wording concise and tie both items
to the existing Hyena additions so readers can find the relevant APIs easily.
In `@docs/source/networks.rst`:
- Around line 132-148: Add documentation for the public helper
is_nvsubquadratic_available() in the Hyena networks docs, since it is part of
the exposed API and users need a runtime way to detect the optional dependency.
Update the Hyena section in networks.rst near HyenaMixer and
HyenaTransformerBlock by adding a Hyena Utilities subsection with an
autofunction entry for monai.networks.blocks.hyena.is_nvsubquadratic_available,
keeping it grouped with the other Hyena-related APIs.
In `@monai/networks/blocks/hyena.py`:
- Around line 95-135: Add Google-style docstrings to the newly introduced
definitions in hyena.py, especially the forward methods and related helpers
referenced by the diff, so each function/class clearly documents its purpose,
all parameters/variables, return value, and any raised exceptions. Update the
affected symbols in the Hyena block implementations and any other new
definitions called out in the review to include the required Args and
Returns/Raises sections, matching the repository’s docstring conventions.
In `@monai/networks/nets/swin_unetr.py`:
- Around line 331-340: The public method load_from in SwinUNETR needs a full
Google-style docstring update because it now has custom checkpoint-loading
behavior without documented parameters, return value, or exceptions. Expand the
existing docstring to add Args for weights (and any other inputs used by
load_from), a Returns section if it returns a value, and a Raises section for
any expected exceptions, while keeping the current summary about Swin and
HyenaTransformerBlock loading behavior.
In `@tests/networks/blocks/test_hyena_block.py`:
- Around line 327-337: The HyenaMixer 3D tests only verify output shape, so they
still pass even if the FFT short-conv path is never used. Update the tests
around HyenaMixer construction to assert the short-conv module type/config when
use_fft_short_conv and short_conv_fft_chunk_size are set, so the branch in
HyenaMixer’s short-conv wiring is explicitly exercised and covered by unit
tests.
In `@tests/networks/nets/test_hyena_nd_unetr.py`:
- Around line 40-95: Add a missing constructor-contract test for the
no-dependency path in TestHyenaNDUNETRConstructorContract so coverage still
exists when HAS_NVSUBQ is false. Create a new test alongside the existing
HyenaNDUNETR constructor tests that is skipped unless nvsubquadratic is absent
and asserts the documented ImportError behavior from HyenaNDUNETR/__init__. Keep
the focus on the constructor symbols HyenaNDUNETR and
TestHyenaNDUNETRConstructorContract, and ensure this path is exercised by a unit
test rather than only the dependency-present cases.
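The duplicate-`hyena_stages` finding above rests on a general Python rule: when the same keyword is given explicitly and again through `**` unpacking, the interpreter raises `TypeError` while binding arguments, so the callee's body never runs. A minimal sketch, assuming a hypothetical `build_net` stand-in (not the real `HyenaNDUNETR` signature):

```python
calls = []  # records every time the body of build_net actually executes


def build_net(**kwargs):
    """Hypothetical constructor stand-in; only logs the kwargs it receives."""
    calls.append(dict(kwargs))
    return kwargs


def call_with_duplicate():
    """Pass hyena_stages both explicitly and via ** unpacking."""
    extra = {"hyena_stages": (False, False, False, False)}
    try:
        build_net(hyena_stages=(True, True, True, True), **extra)
    except TypeError as exc:
        # Raised at the call site, before build_net's body is entered.
        return str(exc)
    return None


duplicate_error = call_with_duplicate()
calls_after_duplicate = len(calls)

# Passing the name once, through ** only, does reach the body.
build_net(**{"hyena_stages": (True, False, False, False)})
calls_after_single = len(calls)
```

Any in-body guard against a duplicated keyword is therefore only testable through a path where the name arrives once.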
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: cb613c72-ab77-43ad-a849-6b6826f9760c
📒 Files selected for processing (14)
- CHANGELOG.md
- docs/source/installation.md
- docs/source/networks.rst
- monai/networks/blocks/__init__.py
- monai/networks/blocks/hyena.py
- monai/networks/nets/__init__.py
- monai/networks/nets/hyena_nd_unetr.py
- monai/networks/nets/swin_unetr.py
- requirements-dev.txt
- setup.cfg
- tests/min_tests.py
- tests/networks/blocks/test_hyena_block.py
- tests/networks/nets/test_hyena_nd_unetr.py
- tests/networks/nets/test_swin_unetr.py
The `hyena` extra pulls in [`nvsubquadratic`](https://github.com/NVIDIA/nvSubquadratic),
required by `HyenaNDUNETR` / `HyenaMixer` / `HyenaTransformerBlock` (subquadratic
O(N log N) alternatives to windowed self-attention). Install with
`pip install 'monai[hyena]'`.
🎯 Functional Correctness | 🟡 Minor | ⚡ Quick win
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Verify GitHub repository URL validity and canonical casing
curl -s -o /dev/null -w "%{http_code}" https://github.com/NVIDIA/nvSubquadratic
curl -s -o /dev/null -w "%{http_code}" https://github.com/NVIDIA/nvsubquadratic
Repository: Project-MONAI/MONAI
Length of output: 163
🌐 Web query:
NVIDIA nvSubquadratic GitHub repository canonical name
💡 Result:
The canonical GitHub repository for the NVIDIA library nvSubquadratic is located at https://github.com/NVIDIA-BioNeMo/nvSubquadratic [1][2]. It is important to distinguish this library from other similarly named projects or entities, such as the AI startup "Subquadratic" (which develops the "SubQ" model) [3][4], or other unrelated NVIDIA repositories like "optix-subd" [5] and "subquadratic-ops-torch" [6]. The nvSubquadratic library is a PyTorch-native toolkit developed by NVIDIA for subquadratic alternatives to traditional quadratic attention mechanisms, specifically utilizing optimized CUDA kernels for operators like Hyena [1][2].
Citations:
- 1: https://pypi.org/project/nvsubquadratic/
- 2: https://pypi.org/project/nvsubquadratic/0.1.0rc0/
- 3: https://www.technologyreview.com/2026/06/19/1139313/a-startup-claims-it-broke-through-a-bottleneck-thats-holding-back-llms/
- 4: https://www.youtube.com/watch?v=uzkTAT81FxA
- 5: https://github.com/nvidia/optix-subd
- 6: https://pypi.org/project/subquadratic-ops-torch-cu12/
Fix the nvSubquadratic link — https://github.com/NVIDIA/nvSubquadratic is dead; use https://github.com/NVIDIA-BioNeMo/nvSubquadratic instead.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@docs/source/installation.md` around lines 263 - 266, Update the `hyena` extra
documentation to point the `nvsubquadratic` hyperlink at the correct
NVIDIA-BioNeMo repository instead of the dead NVIDIA URL. Locate the `hyena`
installation description in the docs and replace the existing link target so
readers using `HyenaNDUNETR`, `HyenaMixer`, and `HyenaTransformerBlock` are
directed to the valid project page.
_LazyConfig, _has_nvsubq = optional_import("nvsubquadratic.lazy_config", name="LazyConfig")
_instantiate, _ = optional_import("nvsubquadratic.lazy_config", name="instantiate")
_Hyena, _ = optional_import("nvsubquadratic.modules.hyena_nd", name="Hyena")
_CKConvND, _ = optional_import("nvsubquadratic.modules.ckconv_nd", name="CKConvND")
_SIRENKernelND, _ = optional_import("nvsubquadratic.modules.kernels_nd", name="SIRENKernelND")
_GaussianModulationND, _ = optional_import("nvsubquadratic.modules.masks_nd", name="GaussianModulationND")
🩺 Stability & Availability | 🟡 Minor | ⚡ Quick win
🧩 Analysis chain
🏁 Script executed:
ast-grep outline monai/networks/blocks/hyena.py --view expanded
Repository: Project-MONAI/MONAI
Length of output: 1102
🏁 Script executed:
sed -n '1,220p' monai/networks/blocks/hyena.py
Repository: Project-MONAI/MONAI
Length of output: 9030
🏁 Script executed:
sed -n '220,380p' monai/networks/blocks/hyena.py
Repository: Project-MONAI/MONAI
Length of output: 7391
🏁 Script executed:
rg -n "is_nvsubquadratic_available\\(" monai
Repository: Project-MONAI/MONAI
Length of output: 234
🏁 Script executed:
sed -n '380,560p' monai/networks/blocks/hyena.py
Repository: Project-MONAI/MONAI
Length of output: 6867
Gate availability on every optional nvsubquadratic import. is_nvsubquadratic_available() only tracks LazyConfig, so it can still return true when instantiate, Hyena, CKConvND, SIRENKernelND, or GaussianModulationND is missing; fold all of those import flags into the helper.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@monai/networks/blocks/hyena.py` around lines 51 - 56, The availability check
in the Hyena module currently depends only on the LazyConfig import, so it can
report nvsubquadratic as available even when instantiate, Hyena, CKConvND,
SIRENKernelND, or GaussianModulationND fail to import. Update the import setup
in hyena.py so each optional_import contributes a flag, and change
is_nvsubquadratic_available() to require all of those symbols to be present
before returning true.
def forward(self, x: torch.Tensor) -> torch.Tensor:
    spatial = x.shape[2:]
    kernel_shape = self.weight.shape[2:]  # type: ignore[attr-defined]
    fft_dims = tuple(range(-self._spatial_dims, 0))
    fft_size = [s + k - 1 for s, k in zip(spatial, kernel_shape)]
    in_dtype = x.dtype

    slices = (slice(None), slice(None)) + tuple(
        slice(k // 2, k // 2 + s) for s, k in zip(spatial, kernel_shape)
    )

    chunk = getattr(self, "fft_chunk_size", 0)
    if chunk > 0 and x.shape[1] > chunk:
        parts = []
        for c0 in range(0, x.shape[1], chunk):
            c1 = min(c0 + chunk, x.shape[1])
            xc = x[:, c0:c1].float()
            kc = self.weight[c0:c1].squeeze(1).float()  # type: ignore[attr-defined]
            kc = kc.flip(list(range(1, self._spatial_dims + 1)))
            xc_fft = torch.fft.rfftn(xc, s=fft_size, dim=fft_dims)
            kc_fft = torch.fft.rfftn(kc, s=fft_size, dim=fft_dims)
            out_fft = xc_fft * kc_fft.unsqueeze(0)
            del xc_fft, kc_fft
            out_c = torch.fft.irfftn(out_fft, s=fft_size, dim=fft_dims)
            del out_fft
            parts.append(out_c[slices].to(in_dtype))
            del out_c
        return torch.cat(parts, dim=1)

    x_f32 = x.float()
    k_f32 = self.weight.squeeze(1).float()  # type: ignore[attr-defined]
    # PyTorch ``F.conv*`` computes cross-correlation; FFT computes convolution.
    # Flip the kernel so the FFT output matches ``Conv{2,3}d`` exactly.
    k_f32 = k_f32.flip(list(range(1, self._spatial_dims + 1)))

    x_fft = torch.fft.rfftn(x_f32, s=fft_size, dim=fft_dims)
    k_fft = torch.fft.rfftn(k_f32, s=fft_size, dim=fft_dims)

    out_fft = x_fft * k_fft.unsqueeze(0)
    out = torch.fft.irfftn(out_fft, s=fft_size, dim=fft_dims)
    return out[slices].to(in_dtype)
🎯 Functional Correctness | 🟠 Major | ⚡ Quick win
Honor self.padding in the FFT crop.
DepthwiseFFTConv{2,3}d accepts padding, but forward() always crops as if padding == kernel_size // 2. Non-default padding, even kernels, or future callers expecting Conv semantics will return the wrong shape/values.
Possible fix direction
-    slices = (slice(None), slice(None)) + tuple(
-        slice(k // 2, k // 2 + s) for s, k in zip(spatial, kernel_shape)
-    )
+    padding = self.padding  # type: ignore[attr-defined]
+    if isinstance(padding, int):
+        padding = (padding,) * self._spatial_dims
+    out_spatial = [s + (2 * p) - k + 1 for s, p, k in zip(spatial, padding, kernel_shape, strict=True)]
+    slices = (
+        slice(None),
+        slice(None),
+        *(slice(k - 1 - p, k - 1 - p + out) for p, k, out in zip(padding, kernel_shape, out_spatial, strict=True)),
+    )
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
def forward(self, x: torch.Tensor) -> torch.Tensor:
    spatial = x.shape[2:]
    kernel_shape = self.weight.shape[2:]  # type: ignore[attr-defined]
    fft_dims = tuple(range(-self._spatial_dims, 0))
    fft_size = [s + k - 1 for s, k in zip(spatial, kernel_shape)]
    in_dtype = x.dtype
    padding = self.padding  # type: ignore[attr-defined]
    if isinstance(padding, int):
        padding = (padding,) * self._spatial_dims
    out_spatial = [s + (2 * p) - k + 1 for s, p, k in zip(spatial, padding, kernel_shape, strict=True)]
    slices = (
        slice(None),
        slice(None),
        *(slice(k - 1 - p, k - 1 - p + out) for p, k, out in zip(padding, kernel_shape, out_spatial, strict=True)),
    )
    chunk = getattr(self, "fft_chunk_size", 0)
    if chunk > 0 and x.shape[1] > chunk:
        parts = []
        for c0 in range(0, x.shape[1], chunk):
            c1 = min(c0 + chunk, x.shape[1])
            xc = x[:, c0:c1].float()
            kc = self.weight[c0:c1].squeeze(1).float()  # type: ignore[attr-defined]
            kc = kc.flip(list(range(1, self._spatial_dims + 1)))
            xc_fft = torch.fft.rfftn(xc, s=fft_size, dim=fft_dims)
            kc_fft = torch.fft.rfftn(kc, s=fft_size, dim=fft_dims)
            out_fft = xc_fft * kc_fft.unsqueeze(0)
            del xc_fft, kc_fft
            out_c = torch.fft.irfftn(out_fft, s=fft_size, dim=fft_dims)
            del out_fft
            parts.append(out_c[slices].to(in_dtype))
            del out_c
        return torch.cat(parts, dim=1)
    x_f32 = x.float()
    k_f32 = self.weight.squeeze(1).float()  # type: ignore[attr-defined]
    # PyTorch ``F.conv*`` computes cross-correlation; FFT computes convolution.
    # Flip the kernel so the FFT output matches ``Conv{2,3}d`` exactly.
    k_f32 = k_f32.flip(list(range(1, self._spatial_dims + 1)))
    x_fft = torch.fft.rfftn(x_f32, s=fft_size, dim=fft_dims)
    k_fft = torch.fft.rfftn(k_f32, s=fft_size, dim=fft_dims)
    out_fft = x_fft * k_fft.unsqueeze(0)
    out = torch.fft.irfftn(out_fft, s=fft_size, dim=fft_dims)
    return out[slices].to(in_dtype)
🧰 Tools
🪛 Ruff (0.15.18)
[warning] 99-99: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
[warning] 102-104: Consider iterable unpacking instead of concatenation
Replace with iterable unpacking
(RUF005)
[warning] 103-103: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@monai/networks/blocks/hyena.py` around lines 95 - 135, `Hyena.forward`
currently hardcodes the FFT crop around `kernel_shape // 2`, so
`DepthwiseFFTConv{2,3}d` ignores its `self.padding` setting and can produce
incorrect outputs for non-default padding or even kernels. Update the crop logic
in `forward()` (including the chunked path) to use the module’s padding value
when building `slices`, and keep the FFT result aligned with Conv{2,3}d
semantics for both output shape and values.
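The crop arithmetic in the suggestion above can be checked in one dimension without any tensor library. The sketch below is an illustration, not the PR's code: it computes the full linear convolution directly (standing in for the FFT product, which yields the same `s + k - 1` samples) and compares the cropped result against a zero-padded cross-correlation reference. All function names are hypothetical.

```python
def full_convolution(x, kernel):
    """Direct linear convolution; output length is len(x) + len(kernel) - 1."""
    out = [0.0] * (len(x) + len(kernel) - 1)
    for i, xv in enumerate(x):
        for j, kv in enumerate(kernel):
            out[i + j] += xv * kv
    return out


def cross_correlate(x, weight, padding):
    """Reference Conv1d semantics: zero-pad, then slide the unflipped kernel."""
    padded = [0.0] * padding + list(x) + [0.0] * padding
    k = len(weight)
    return [sum(weight[j] * padded[i + j] for j in range(k)) for i in range(len(padded) - k + 1)]


def cropped_conv(x, weight, padding):
    """Full convolution with the flipped kernel, cropped to honor `padding`."""
    k, s = len(weight), len(x)
    if not 0 <= padding <= k - 1:
        raise ValueError("this crop only covers 0 <= padding <= kernel_size - 1")
    full = full_convolution(x, weight[::-1])  # flip: convolution -> cross-correlation
    out_len = s + 2 * padding - k + 1         # Conv output length for stride 1
    start = k - 1 - padding                   # offset of the first valid output sample
    return full[start:start + out_len]


signal = [1.0, 2.0, 3.0, 4.0, 5.0]
even_kernel = [1.0, 0.0, -1.0, 2.0]  # even length, where a k // 2 crop would differ
```

For an odd kernel with `padding == k // 2`, `k - 1 - padding` equals `k // 2`, so the hardcoded crop is the special case of this formula; paddings larger than `k - 1` would need extra zero-padding of the input and are outside this sketch.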
warnings.warn(
    f"Skipping {layer_name} block weights: stage uses HyenaTransformerBlock, "
    "which has no compatible Swin attention weights. Blocks remain at their "
    "random initialization."
)
📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win
Set stacklevel on the warning.
Without it, callers see the warning location inside load_from instead of their call site.
Proposed fix
 warnings.warn(
     f"Skipping {layer_name} block weights: stage uses HyenaTransformerBlock, "
     "which has no compatible Swin attention weights. Blocks remain at their "
-    "random initialization."
+    "random initialization.",
+    stacklevel=2,
 )
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
warnings.warn(
    f"Skipping {layer_name} block weights: stage uses HyenaTransformerBlock, "
    "which has no compatible Swin attention weights. Blocks remain at their "
    "random initialization.",
    stacklevel=2,
)
🧰 Tools
🪛 Ruff (0.15.18)
[warning] 363-363: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@monai/networks/nets/swin_unetr.py` around lines 363 - 367, The warning
emitted in `load_from` for `HyenaTransformerBlock` should set an explicit
stacklevel so the caller sees the warning at their call site instead of inside
this helper. Update the existing warnings.warn call in `swin_unetr.py` to
include a stacklevel argument, keeping the same message and using the
`load_from` path where `layer_name` and `HyenaTransformerBlock` are handled.
Source: Linters/SAST tools
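To see what `stacklevel` changes, the sketch below records the same warning twice and compares the line each one is attributed to. `load_from_stub` and `call_site` are hypothetical names used only for this illustration.

```python
import warnings


def load_from_stub(stacklevel):
    """Hypothetical stand-in for a loader that warns about skipped weights."""
    warnings.warn("Skipping block weights for this stage.", UserWarning, stacklevel=stacklevel)


def call_site(stacklevel):
    """Call the stub and return the recorded warning so its location can be inspected."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        load_from_stub(stacklevel)
    return caught[0]


inside = call_site(stacklevel=1)   # attributed to the warn() line inside the stub
outside = call_site(stacklevel=2)  # attributed to the line in call_site that invoked the stub
```

With `stacklevel=2` the recorded `lineno` falls inside `call_site`, which is the location a user would need in order to find their own `load_from` call.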
self.use_hyena = use_hyena
if use_hyena:
    self.blocks = nn.ModuleList(
        [
            HyenaTransformerBlock(
                dim=dim,
                spatial_dims=len(self.window_size),
                mlp_ratio=mlp_ratio,
                drop=drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer,
                use_checkpoint=use_checkpoint,
                hyena_kernel_size=hyena_kernel_size,
                hyena_kernel_mlp_dim=hyena_kernel_mlp_dim,
                hyena_kernel_layers=hyena_kernel_layers,
                hyena_mask_max_attenuation=hyena_mask_max_attenuation,
                hyena_fft_padding=hyena_fft_padding,
                hyena_grid_type=hyena_grid_type,
                hyena_use_chunked_fft=hyena_use_chunked_fft,
                hyena_use_fft_short_conv=hyena_use_fft_short_conv,
                hyena_omega_0=hyena_omega_0,
                hyena_l_cache=hyena_l_cache,
                hyena_short_conv_fft_chunks=hyena_short_conv_fft_chunks,
            )
            for i in range(depth)
        ]
    )
else:
    self.blocks = nn.ModuleList(
        [
            SwinTransformerBlock(
                dim=dim,
                num_heads=num_heads,
                window_size=self.window_size,
                shift_size=self.no_shift if (i % 2 == 0) else self.shift_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer,
                use_checkpoint=use_checkpoint,
            )
            for i in range(depth)
        ]
    )
🚀 Performance & Scalability | 🟠 Major | ⚡ Quick win
Skip attention-mask construction for Hyena stages.
self.use_hyena is set here, but forward() still builds compute_mask(...); HyenaTransformerBlock.forward ignores mask_matrix, so Hyena stages keep paying attention-mask memory/latency.
Proposed fix
 if len(x_shape) == 5:
     b, c, d, h, w = x_shape
-    window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size)
     x = rearrange(x, "b c d h w -> b d h w c")
-    dp = int(np.ceil(d / window_size[0])) * window_size[0]
-    hp = int(np.ceil(h / window_size[1])) * window_size[1]
-    wp = int(np.ceil(w / window_size[2])) * window_size[2]
-    attn_mask = compute_mask([dp, hp, wp], window_size, shift_size, x.device)
+    attn_mask = None
+    if not self.use_hyena:
+        window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size)
+        dp = int(np.ceil(d / window_size[0])) * window_size[0]
+        hp = int(np.ceil(h / window_size[1])) * window_size[1]
+        wp = int(np.ceil(w / window_size[2])) * window_size[2]
+        attn_mask = compute_mask([dp, hp, wp], window_size, shift_size, x.device)
     for blk in self.blocks:
         x = blk(x, attn_mask)
@@
 elif len(x_shape) == 4:
     b, c, h, w = x_shape
-    window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size)
     x = rearrange(x, "b c h w -> b h w c")
-    hp = int(np.ceil(h / window_size[0])) * window_size[0]
-    wp = int(np.ceil(w / window_size[1])) * window_size[1]
-    attn_mask = compute_mask([hp, wp], window_size, shift_size, x.device)
+    attn_mask = None
+    if not self.use_hyena:
+        window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size)
+        hp = int(np.ceil(h / window_size[0])) * window_size[0]
+        wp = int(np.ceil(w / window_size[1])) * window_size[1]
+        attn_mask = compute_mask([hp, wp], window_size, shift_size, x.device)
     for blk in self.blocks:
         x = blk(x, attn_mask)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@monai/networks/nets/swin_unetr.py` around lines 956 - 1001, The Hyena path
still pays the cost of attention-mask construction even though
HyenaTransformerBlock.forward ignores mask_matrix. Update the stage forward
logic that uses self.use_hyena and compute_mask(...) so mask generation is
skipped entirely for Hyena stages, while preserving the existing
SwinTransformerBlock path for non-Hyena stages. Keep the change localized to the
stage module that initializes self.blocks and dispatches to the block forward
calls.
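The control flow suggested in this comment can be sketched without any MONAI or PyTorch dependency. The names `padded_shape`, `stage_mask`, and `fake_build_mask` are hypothetical stand-ins (the last one for `compute_mask`), and the sketch omits the window clamping that `get_window_size` performs in the real code.

```python
import math


def padded_shape(spatial_shape, window_size):
    """Round each spatial dimension up to a multiple of its window size."""
    return [math.ceil(s / w) * w for s, w in zip(spatial_shape, window_size)]


def stage_mask(spatial_shape, window_size, use_hyena, build_mask):
    """Return None for Hyena stages; otherwise build the attention mask."""
    if use_hyena:
        return None  # Hyena blocks ignore the mask, so skip the work entirely
    return build_mask(padded_shape(spatial_shape, window_size))


calls = []


def fake_build_mask(shape):
    # Hypothetical stand-in for compute_mask; records each invocation.
    calls.append(shape)
    return ("mask", tuple(shape))


hyena_mask = stage_mask((65, 64, 64), (7, 7, 7), True, fake_build_mask)
swin_mask = stage_mask((65, 64, 64), (7, 7, 7), False, fake_build_mask)
```

Only the non-Hyena call reaches the mask builder, which is the saving the review is asking for.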
| def test_duplicate_hyena_stages_kwarg_rejected(self): | ||
| with self.assertRaisesRegex(TypeError, "hyena_stages"): | ||
| HyenaNDUNETR( | ||
| in_channels=1, | ||
| out_channels=14, | ||
| feature_size=12, | ||
| hyena_stages=(True, True, False, False), | ||
| **{"hyena_stages": (True, False, True, False)}, # type: ignore[arg-type] | ||
| ) |
🎯 Functional Correctness | 🟡 Minor | ⚡ Quick win
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
set -euo pipefail
printf '\n== test file slice ==\n'
sed -n '1,180p' tests/networks/nets/test_hyena_nd_unetr.py
printf '\n== implementation slice ==\n'
sed -n '1,240p' monai/networks/nets/hyena_nd_unetr.py
printf '\n== search for related constructor/tests ==\n'
rg -n "hyena_stages|HAS_NVSUBQ|ImportError|TypeError" tests/networks/nets/test_hyena_nd_unetr.py monai/networks/nets/hyena_nd_unetr.py
Repository: Project-MONAI/MONAI
Length of output: 20034
Remove the unreachable hyena_stages kwarg check. Passing hyena_stages both explicitly and via **{...} raises TypeError before HyenaNDUNETR.__init__ runs, so this test only hits call-site parsing and the if "hyena_stages" in kwargs: branch can never fire.
🧰 Tools
🪛 Ruff (0.15.18)
[error] 83-83: Repeated keyword argument: hyena_stages
(PLE1132)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tests/networks/nets/test_hyena_nd_unetr.py` around lines 76 - 84, The
duplicate hyena_stages kwarg assertion in
test_duplicate_hyena_stages_kwarg_rejected is unreachable because Python raises
TypeError at the call site before HyenaNDUNETR.__init__ can inspect kwargs.
Remove this test or rewrite it to cover a reachable duplicate-argument path, and
if keeping a kwargs validation check in HyenaNDUNETR.__init__, add a separate
test that passes hyena_stages only through kwargs so the branch can actually
execute.
Source: Linters/SAST tools
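The call-site behavior described here can be reproduced with a plain function. This is a sketch under the assumption that `hyena_stages` is a named parameter of the constructor; `build_net` is a hypothetical stand-in for `HyenaNDUNETR.__init__`, not its real signature.

```python
body_calls = []


def build_net(hyena_stages=(False, False, False, False), **kwargs):
    body_calls.append(hyena_stages)  # records every call that reaches the body
    if "hyena_stages" in kwargs:
        # Unreachable: the named parameter always captures this key.
        raise TypeError("hyena_stages passed twice")
    return hyena_stages


try:
    # Python rejects the repeated keyword while binding arguments,
    # before the function body starts executing.
    build_net(
        hyena_stages=(True, True, False, False),
        **{"hyena_stages": (True, False, True, False)},
    )
except TypeError as exc:
    call_site_error = str(exc)

# Passing the key only through ** binds it to the named parameter instead.
via_kwargs_only = build_net(**{"hyena_stages": (True, False, True, False)})
```

After running this, `body_calls` holds a single entry: the duplicate call never entered the function, so a test built on it exercises only Python's argument binding.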
| @skip_if_no_cuda | ||
| def test_default_path_unchanged(self): | ||
| """SwinUNETR with no hyena kwargs produces the pre-port golden output.""" | ||
| import hashlib | ||
|
|
||
| torch.manual_seed(0) | ||
| torch.cuda.manual_seed_all(0) | ||
| net = SwinUNETR(in_channels=1, out_channels=14, feature_size=12).cuda().eval() | ||
| torch.manual_seed(1) | ||
| x = torch.randn(1, 1, 64, 64, 64, device="cuda") | ||
| with torch.no_grad(): | ||
| out = net(x) | ||
| h = hashlib.sha256(out.flatten()[:64].cpu().numpy().tobytes()).hexdigest() | ||
| self.assertEqual( | ||
| h, | ||
| HYENA_BACKCOMPAT_GOLDEN_HASH, | ||
| "Default SwinUNETR forward output drifted; HyenaND port changed the non-Hyena path.", | ||
| ) |
🩺 Stability & Availability | 🟠 Major | ⚡ Quick win
🧩 Analysis chain
🏁 Script executed:
ast-grep outline tests/networks/nets/test_swin_unetr.py --view expanded
Repository: Project-MONAI/MONAI
Length of output: 1974
🏁 Script executed:
sed -n '1,260p' tests/networks/nets/test_swin_unetr.py
Repository: Project-MONAI/MONAI
Length of output: 10596
🏁 Script executed:
python3 - <<'PY'
from pathlib import Path
p = Path('tests/networks/nets/test_swin_unetr.py')
text = p.read_text()
for name in ['HYENA_BACKCOMPAT_GOLDEN_HASH', 'skip_if_no_cuda']:
print(f'## {name}')
for i, line in enumerate(text.splitlines(), 1):
if name in line:
print(i, line)
PY
Repository: Project-MONAI/MONAI
Length of output: 505
🏁 Script executed:
git diff --stat -- tests/networks/nets/test_swin_unetr.py
Repository: Project-MONAI/MONAI
Length of output: 157
🏁 Script executed:
rg -n "sha256|golden|bit-identical|assert_close|flatten\(\)\[:64\]" tests monai -g '*.py'
Repository: Project-MONAI/MONAI
Length of output: 5305
Avoid hashing raw CUDA bytes here. This golden check will fail on harmless GPU/CUDA/PyTorch numeric drift even if the non-Hyena path is still correct. Compare against a stored reference tensor with torch.testing.assert_close(...) or another tolerance-based invariant instead.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tests/networks/nets/test_swin_unetr.py` around lines 170 - 187, The golden
check in test_default_path_unchanged is too strict because it hashes raw CUDA
bytes, so replace the SHA256 comparison with a tolerance-based validation such
as torch.testing.assert_close against a stored reference tensor or another
numeric invariant. Keep the test focused on the SwinUNETR default forward path
and preserve the existing setup in test_default_path_unchanged, but avoid
byte-level equality that can fail from harmless GPU/PyTorch drift.
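The fragility of a byte-level hash can be shown with the standard library alone. This is an analogue, not the test itself: plain Python floats stand in for the CUDA output tensor, the `1e-12` relative drift is an arbitrary illustrative value, and `math.isclose` plays the role that `torch.testing.assert_close` (with `rtol`/`atol`) would play in the real test.

```python
import hashlib
import math
import struct


def digest(values):
    """SHA-256 over the raw float64 bytes, mirroring the byte-level golden hash."""
    return hashlib.sha256(struct.pack(f"{len(values)}d", *values)).hexdigest()


reference = [0.1 * i for i in range(8)]
# Simulate harmless numeric drift, e.g. from a different kernel or library version.
drifted = [v * (1.0 + 1e-12) for v in reference]

hashes_match = digest(reference) == digest(drifted)
close_enough = all(
    math.isclose(a, b, rel_tol=1e-6, abs_tol=1e-9)
    for a, b in zip(reference, drifted)
)
```

The hash comparison fails on the drifted values while the tolerance check passes, which is why a stored reference tensor with explicit tolerances is the more stable invariant.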
SwinUNETR + HyenaND: subquadratic alternative to windowed self-attention
Upstreams the four-variant PanTS matrix from NeurIPS 2026 paper id 26539
(Native Multi-Dimensional Subquadratic Operators via Input Dependent Long
Convolutions). HyenaND replaces windowed self-attention with a gated long
convolution computed via FFT, which gives a global receptive field at
O(N log N) cost instead of attention's O(N²) cost within each window.
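As background for the O(N log N) claim, the following is a 1-D NumPy illustration of the general FFT convolution idea: a kernel as long as the input is applied through zero-padded FFTs and checked against direct convolution. It is a sketch of the textbook technique only; it does not reproduce the gating, the N-dimensional layout, or the `nvsubquadratic` implementation, and `fft_linear_conv` is a hypothetical helper name.

```python
import numpy as np


def fft_linear_conv(signal, kernel):
    """Linear convolution via zero-padded real FFTs."""
    n = len(signal) + len(kernel) - 1  # pad so circular wrap-around cannot occur
    spectrum = np.fft.rfft(signal, n=n) * np.fft.rfft(kernel, n=n)
    return np.fft.irfft(spectrum, n=n)


rng = np.random.default_rng(0)
signal = rng.standard_normal(256)
kernel = rng.standard_normal(256)  # kernel spans the whole input: global receptive field

direct = np.convolve(signal, kernel)    # direct evaluation, quadratic in length
fast = fft_linear_conv(signal, kernel)  # FFT evaluation, O(N log N)
max_err = float(np.max(np.abs(direct - fast)))
```

Both paths produce the same result up to floating-point rounding; only the cost of computing it differs.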
Public surface
- `monai.networks.blocks`: `HyenaMixer`, `HyenaTransformerBlock`, `DepthwiseFFTConv{2,3}d`.
- `monai.networks.nets.SwinUNETR`: new `use_hyena` / `hyena_stages` / `hyena_*` kwargs threaded through `SwinTransformer` → `BasicLayer`.
- `monai.networks.nets.HyenaNDUNETR`: thin `SwinUNETR` subclass with `from_paper_variant("HHHH" | "HAHA" | "HHAA")`.
- `[hyena]` extras_require → `pip install 'monai[hyena]'` (`nvsubquadratic` 0.1.0 on PyPI).

`nvsubquadratic` is gated through `optional_import`; `SwinUNETR(use_hyena=False)` never imports it.
Tests
72 new tests across test_hyena_block.py (40), test_swin_unetr.py (15 Hyena classes), test_hyena_nd_unetr.py (17).
Types of changes
- `./runtests.sh -f -u --net --coverage`.
- `./runtests.sh --quick --unittests --disttests`.
- `make html` command in the `docs/` folder.