SwinUNETR + HyenaND: subquadratic alternative to windowed self-attention #8958

Open
farhadrgh wants to merge 8 commits into Project-MONAI:dev from farhadrgh:farhadr/hyena

@farhadrgh

SwinUNETR + HyenaND: subquadratic alternative to windowed self-attention

Upstreams the four-variant PanTS matrix from NeurIPS 2026 paper id 26539
(Native Multi-Dimensional Subquadratic Operators via Input Dependent Long
Convolutions). HyenaND replaces windowed self-attention with a gated long
convolution computed via FFT, which gives a global receptive field at
O(N log N) cost instead of attention's O(N²) cost within each window.

Public surface

  • monai.networks.blocks: HyenaMixer, HyenaTransformerBlock,
    DepthwiseFFTConv{2,3}d.
  • monai.networks.nets.SwinUNETR: new use_hyena / hyena_stages /
    hyena_* kwargs threaded through SwinTransformerBasicLayer.
  • monai.networks.nets.HyenaNDUNETR: thin SwinUNETR subclass with
    get_variant("HHHH" | "HAHA" | "HHAA").
  • New [hyena] extras_require → pip install 'monai[hyena]'
    (nvsubquadratic 0.1.0 on PyPI).

nvsubquadratic is gated through optional_import; SwinUNETR(use_hyena=False) never imports it.

from monai.networks.nets import HyenaNDUNETR
net = HyenaNDUNETR.get_variant("HHAA", in_channels=1, out_channels=29, feature_size=48)
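
The preset names can be read as one letter per encoder stage. The sketch below assumes, without confirmation from the PR text, that `H` selects a HyenaND stage and `A` a windowed-attention stage across four stages; `parse_variant` is a hypothetical helper for illustration, not part of the proposed API.

```python
from typing import Tuple


def parse_variant(name: str) -> Tuple[bool, ...]:
    """Map a preset such as "HHAA" to per-stage flags (True = Hyena stage).

    Hypothetical helper: the letter-to-stage mapping is an assumption.
    """
    if len(name) != 4 or set(name) - {"H", "A"}:
        raise ValueError(f"expected four letters drawn from 'H'/'A', got {name!r}")
    return tuple(letter == "H" for letter in name)


for preset in ("HHHH", "HAHA", "HHAA"):
    print(preset, parse_variant(preset))
```

Under this reading, "HHAA" would place HyenaND in the two highest-resolution stages and keep attention in the two deepest ones.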

Tests

72 new tests across test_hyena_block.py (40), test_swin_unetr.py (15 Hyena classes), test_hyena_nd_unetr.py (17).

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

farhadrgh added 6 commits May 29, 2026 13:40
Signed-off-by: Farhad Ramezanghorbani <farhadr@nvidia.com>
@coderabbitai

coderabbitai Bot commented Jun 29, 2026

Walkthrough

Adds HyenaND-based sequence modeling blocks to MONAI: DepthwiseFFTConv2d/3d (FFT-based depthwise conv avoiding INT32 constraints), HyenaMixer (wrapping optional nvsubquadratic HyenaND with QKV projection and float32 FFT path), and HyenaTransformerBlock (pre-norm residual block). These are threaded into SwinUNETR/SwinTransformer/BasicLayer via use_hyena and hyena_stages parameters with per-stage selection and RoPE divisibility validation. load_from skips Hyena stages with a warning. HyenaNDUNETR subclasses SwinUNETR with enforced stage configs and get_variant for three paper presets. Accompanied by full test coverage, packaging (monai[hyena]), and docs.

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 18.29% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The title clearly summarizes the main change: adding HyenaND as a subquadratic alternative for SwinUNETR self-attention. |
| Description check | ✅ Passed | The description is mostly complete, with clear change summary, tests, and change-type checklist; only the issue reference is omitted. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |

@farhadrgh farhadrgh changed the title from "Farhadr/hyena" to "SwinUNETR + HyenaND: subquadratic alternative to windowed self-attention" Jun 29, 2026

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 7

🧹 Nitpick comments (6)
docs/source/networks.rst (1)

132-148: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win

Consider documenting is_nvsubquadratic_available().

The function is public API (__all__) and lets users detect the optional dependency at runtime. Add an autofunction entry near the Hyena block classes:

`Hyena Utilities`
~~~~~~~~~~~~~~~~~
.. autofunction:: monai.networks.blocks.hyena.is_nvsubquadratic_available
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@docs/source/networks.rst` around lines 132 - 148, Add documentation for the
public helper is_nvsubquadratic_available() in the Hyena networks docs, since it
is part of the exposed API and users need a runtime way to detect the optional
dependency. Update the Hyena section in networks.rst near HyenaMixer and
HyenaTransformerBlock by adding a Hyena Utilities subsection with an
autofunction entry for monai.networks.blocks.hyena.is_nvsubquadratic_available,
keeping it grouped with the other Hyena-related APIs.
CHANGELOG.md (1)

7-11: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win

Consider adding is_nvsubquadratic_available() and load_from skip behavior to the changelog.

The entry is accurate but omits two user-visible behaviors:

  • monai.networks.blocks.hyena.is_nvsubquadratic_available() for runtime optional-dependency detection.
  • SwinUNETR.load_from now skips Hyena stages with a warning rather than failing.

Both are worth noting for users integrating Hyena conditionally.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@CHANGELOG.md` around lines 7 - 11, The Hyena changelog entry is missing two
user-visible behaviors that should be called out for users integrating the
optional dependency conditionally. Update the `CHANGELOG.md` “Added” section to
mention `monai.networks.blocks.hyena.is_nvsubquadratic_available()` as the
runtime availability check, and note that `SwinUNETR.load_from` now skips Hyena
stages with a warning instead of failing. Keep the wording concise and tie both
items to the existing Hyena additions so readers can find the relevant APIs
easily.
monai/networks/blocks/hyena.py (1)

95-135: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win

Add the required Google-style docstrings.

Several new definitions lack docstrings or Args/Returns/Raises sections. As per path instructions, “Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings.”

Also applies to: 149-177, 189-217, 389-413, 517-548

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/networks/blocks/hyena.py` around lines 95 - 135, Add Google-style
docstrings to the newly introduced definitions in hyena.py, especially the
forward methods and related helpers referenced by the diff, so each
function/class clearly documents its purpose, all parameters/variables, return
value, and any raised exceptions. Update the affected symbols in the Hyena block
implementations and any other new definitions called out in the review to
include the required Args and Returns/Raises sections, matching the repository’s
docstring conventions.

Source: Path instructions

tests/networks/blocks/test_hyena_block.py (1)

327-337: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win

Assert the FFT short-conv wiring.

These tests still pass if HyenaMixer ignores use_fft_short_conv or short_conv_fft_chunk_size, because plain nn.Conv3d preserves the same shape. Add a type/config assertion on the constructed short-conv module so the branch in monai/networks/blocks/hyena.py:274-368 is actually covered. As per path instructions, Ensure new or modified definitions will be covered by existing or new unit tests.
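
A minimal sketch of the kind of wiring assertion meant here, using stand-in classes so it runs without MONAI or nvsubquadratic. `MixerStub`, `FFTConvStub`, `PlainConvStub`, and the `short_conv` attribute name are all hypothetical; only the flag names and `fft_chunk_size` come from the review text and the quoted code.

```python
class PlainConvStub:
    """Stand-in for nn.Conv3d; calling it returns the input shape unchanged."""

    def __call__(self, shape):
        return shape


class FFTConvStub(PlainConvStub):
    """Stand-in for DepthwiseFFTConv3d; also shape-preserving."""

    def __init__(self, fft_chunk_size: int = 0) -> None:
        self.fft_chunk_size = fft_chunk_size


class MixerStub:
    """Hypothetical mixer; `short_conv` is an assumed attribute name."""

    def __init__(self, use_fft_short_conv: bool = False, short_conv_fft_chunk_size: int = 0) -> None:
        if use_fft_short_conv:
            self.short_conv = FFTConvStub(short_conv_fft_chunk_size)
        else:
            self.short_conv = PlainConvStub()


mixer = MixerStub(use_fft_short_conv=True, short_conv_fft_chunk_size=8)
shape = (1, 16, 4, 4, 4)

# A shape check alone passes for either module type, so it cannot detect ignored flags.
assert mixer.short_conv(shape) == shape
assert MixerStub().short_conv(shape) == shape

# These assertions fail if the flags are silently ignored.
assert isinstance(mixer.short_conv, FFTConvStub)
assert mixer.short_conv.fft_chunk_size == 8
```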

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/networks/blocks/test_hyena_block.py` around lines 327 - 337, The
HyenaMixer 3D tests only verify output shape, so they still pass even if the FFT
short-conv path is never used. Update the tests around HyenaMixer construction
to assert the short-conv module type/config when use_fft_short_conv and
short_conv_fft_chunk_size are set, so the branch in HyenaMixer’s short-conv
wiring is explicitly exercised and covered by unit tests.

Source: Path instructions

tests/networks/nets/test_hyena_nd_unetr.py (1)

40-95: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win

Add the missing no-dependency constructor test.

Every constructor-contract case is skipped when nvsubquadratic is absent, but HyenaNDUNETR documents an ImportError path in monai/networks/nets/hyena_nd_unetr.py:59-165. Add a @skipUnless(not HAS_NVSUBQ, ...) case so that contract stays covered. As per path instructions, Ensure new or modified definitions will be covered by existing or new unit tests.
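
The requested test could look roughly like the sketch below. It uses only `unittest`; `build_hyena_net` is a hypothetical stand-in for the `HyenaNDUNETR(...)` call, and `HAS_NVSUBQ` is hard-coded here, whereas the real test module would presumably derive it from the optional import.

```python
import unittest

HAS_NVSUBQ = False  # assumed flag; hard-coded for this sketch


def build_hyena_net():
    """Stand-in for HyenaNDUNETR(...): raises when the dependency is missing."""
    if not HAS_NVSUBQ:
        raise ImportError("nvsubquadratic is required; install with pip install 'monai[hyena]'")
    return object()


class TestConstructorWithoutDependency(unittest.TestCase):
    @unittest.skipUnless(not HAS_NVSUBQ, "only meaningful when nvsubquadratic is absent")
    def test_missing_dependency_raises(self):
        # The documented contract: construction fails with ImportError.
        with self.assertRaises(ImportError):
            build_hyena_net()
```

Pairing this with the existing dependency-present cases means one of the two groups always runs, whichever environment the suite executes in.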

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/networks/nets/test_hyena_nd_unetr.py` around lines 40 - 95, Add a
missing constructor-contract test for the no-dependency path in
TestHyenaNDUNETRConstructorContract so coverage still exists when HAS_NVSUBQ is
false. Create a new test alongside the existing HyenaNDUNETR constructor tests
that is skipped unless nvsubquadratic is absent and asserts the documented
ImportError behavior from HyenaNDUNETR/__init__. Keep the focus on the
constructor symbols HyenaNDUNETR and TestHyenaNDUNETRConstructorContract, and
ensure this path is exercised by a unit test rather than only the
dependency-present cases.

Source: Path instructions

monai/networks/nets/swin_unetr.py (1)

331-340: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win

Add Google-style sections to load_from.

This public method now has custom checkpoint-loading behavior, but weights, return value, and expected exceptions are not documented.

Proposed docstring update
     def load_from(self, weights):
         """Load pretrained Swin weights into the matching submodules.
 
         When a stage uses :class:`HyenaTransformerBlock` instead of
         :class:`SwinTransformerBlock`, the per-block ``load_from`` call is skipped for
         that stage and a warning is issued -- HyenaND has a different parameter layout
         and there are no compatible attention weights to copy. PatchMerging
         downsample weights are still loaded for all stages (the downsample layer is
         the same in both code paths).
+
+        Args:
+            weights: Checkpoint mapping containing a ``"state_dict"`` with Swin
+                pretrained parameter tensors.
+
+        Returns:
+            None.
+
+        Raises:
+            KeyError: If an expected checkpoint key is absent.
+            RuntimeError: If a checkpoint tensor shape is incompatible.
         """

As per path instructions, “Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings.”

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/networks/nets/swin_unetr.py` around lines 331 - 340, The public method
load_from in SwinUNETR needs a full Google-style docstring update because it now
has custom checkpoint-loading behavior without documented parameters, return
value, or exceptions. Expand the existing docstring to add Args for weights (and
any other inputs used by load_from), a Returns section if it returns a value,
and a Raises section for any expected exceptions, while keeping the current
summary about Swin and HyenaTransformerBlock loading behavior.

Source: Path instructions

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@docs/source/installation.md`:
- Around line 263-266: Update the `hyena` extra documentation to point the
`nvsubquadratic` hyperlink at the correct NVIDIA-BioNeMo repository instead of
the dead NVIDIA URL. Locate the `hyena` installation description in the docs and
replace the existing link target so readers using `HyenaNDUNETR`, `HyenaMixer`,
and `HyenaTransformerBlock` are directed to the valid project page.

In `@monai/networks/blocks/hyena.py`:
- Around line 51-56: The availability check in the Hyena module currently
depends only on the LazyConfig import, so it can report nvsubquadratic as
available even when instantiate, Hyena, CKConvND, SIRENKernelND, or
GaussianModulationND fail to import. Update the import setup in hyena.py so each
optional_import contributes a flag, and change is_nvsubquadratic_available() to
require all of those symbols to be present before returning true.
- Around line 95-135: The `forward()` of `DepthwiseFFTConv{2,3}d` currently
hardcodes the FFT crop at `k // 2` per kernel dimension, so the module ignores
its `self.padding` setting and can produce incorrect outputs for non-default
padding or even kernels. Update the crop logic in `forward()` (including the
chunked path) to use the module’s padding value when building `slices`, and keep
the FFT result aligned with Conv{2,3}d semantics for both output shape and values.

In `@monai/networks/nets/swin_unetr.py`:
- Around line 363-367: The warning emitted in `load_from` for
`HyenaTransformerBlock` should set an explicit stacklevel so the caller sees the
warning at their call site instead of inside this helper. Update the existing
warnings.warn call in `swin_unetr.py` to include a stacklevel argument, keeping
the same message and using the `load_from` path where `layer_name` and
`HyenaTransformerBlock` are handled.
- Around line 956-1001: The Hyena path still pays the cost of attention-mask
construction even though HyenaTransformerBlock.forward ignores mask_matrix.
Update the stage forward logic that uses self.use_hyena and compute_mask(...) so
mask generation is skipped entirely for Hyena stages, while preserving the
existing SwinTransformerBlock path for non-Hyena stages. Keep the change
localized to the stage module that initializes self.blocks and dispatches to the
block forward calls.

In `@tests/networks/nets/test_hyena_nd_unetr.py`:
- Around line 76-84: The duplicate hyena_stages kwarg assertion in
test_duplicate_hyena_stages_kwarg_rejected is unreachable because Python raises
TypeError at the call site before HyenaNDUNETR.__init__ can inspect kwargs.
Remove this test or rewrite it to cover a reachable duplicate-argument path, and
if keeping a kwargs validation check in HyenaNDUNETR.__init__, add a separate
test that passes hyena_stages only through kwargs so the branch can actually
execute.
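
The behaviour this finding relies on can be reproduced in isolation. The sketch assumes the test supplies the keyword both explicitly and through `**` unpacking; `build` is a hypothetical stand-in for the constructor.

```python
def build(**kwargs):
    """Stand-in constructor whose body would validate kwargs."""
    if "hyena_stages" in kwargs:
        return "body saw hyena_stages"
    return "body ran"


def duplicate_call_is_rejected() -> bool:
    """Return True if Python rejects the call before `build` executes."""
    try:
        # The same keyword is supplied explicitly and again through ** unpacking.
        build(hyena_stages=(True,) * 4, **{"hyena_stages": (False,) * 4})
    except TypeError:
        return True
    return False


print(duplicate_call_is_rejected())
```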

In `@tests/networks/nets/test_swin_unetr.py`:
- Around line 170-187: The golden check in test_default_path_unchanged is too
strict because it hashes raw CUDA bytes, so replace the SHA256 comparison with a
tolerance-based validation such as torch.testing.assert_close against a stored
reference tensor or another numeric invariant. Keep the test focused on the
SwinUNETR default forward path and preserve the existing setup in
test_default_path_unchanged, but avoid byte-level equality that can fail from
harmless GPU/PyTorch drift.
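
To illustrate why a byte-level hash is brittle, here is a standard-library sketch. `digest` and `all_close` are illustrative helpers; `all_close` is only a rough plain-Python analogue of `torch.testing.assert_close`, and the tolerances are arbitrary.

```python
import hashlib
import math
import struct
from typing import List


def digest(values: List[float]) -> str:
    """SHA-256 over the raw float32 bytes, as a byte-level golden check would do."""
    return hashlib.sha256(struct.pack(f"<{len(values)}f", *values)).hexdigest()


def all_close(actual: List[float], expected: List[float], rtol: float = 1e-5, atol: float = 1e-6) -> bool:
    """Element-wise tolerance check."""
    return len(actual) == len(expected) and all(
        math.isclose(a, e, rel_tol=rtol, abs_tol=atol) for a, e in zip(actual, expected)
    )


reference = [0.125, -1.5, 3.25]
drifted = [0.125, -1.5, 3.25 + 2.0 ** -20]  # one value differs by roughly 1e-6

print("hashes equal:", digest(reference) == digest(drifted))
print("within tolerance:", all_close(drifted, reference))
```

A perturbation of a few float32 ULPs changes the hash but stays well inside the tolerance, which is the kind of drift different GPUs or PyTorch versions can introduce.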

📥 Commits

Reviewing files that changed from the base of the PR and between b7d14c8 and e910366.

📒 Files selected for processing (14)
  • CHANGELOG.md
  • docs/source/installation.md
  • docs/source/networks.rst
  • monai/networks/blocks/__init__.py
  • monai/networks/blocks/hyena.py
  • monai/networks/nets/__init__.py
  • monai/networks/nets/hyena_nd_unetr.py
  • monai/networks/nets/swin_unetr.py
  • requirements-dev.txt
  • setup.cfg
  • tests/min_tests.py
  • tests/networks/blocks/test_hyena_block.py
  • tests/networks/nets/test_hyena_nd_unetr.py
  • tests/networks/nets/test_swin_unetr.py

Comment on lines +263 to +266
The `hyena` extra pulls in [`nvsubquadratic`](https://github.com/NVIDIA/nvSubquadratic),
required by `HyenaNDUNETR` / `HyenaMixer` / `HyenaTransformerBlock` (subquadratic
O(N log N) alternatives to windowed self-attention). Install with
`pip install 'monai[hyena]'`.

🎯 Functional Correctness | 🟡 Minor | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify GitHub repository URL validity and canonical casing
curl -s -o /dev/null -w "%{http_code}" https://github.com/NVIDIA/nvSubquadratic
curl -s -o /dev/null -w "%{http_code}" https://github.com/NVIDIA/nvsubquadratic

Repository: Project-MONAI/MONAI

Length of output: 163


🌐 Web query:

NVIDIA nvSubquadratic GitHub repository canonical name

💡 Result:

The canonical GitHub repository for the NVIDIA library nvSubquadratic is located at https://github.com/NVIDIA-BioNeMo/nvSubquadratic [1][2]. It is important to distinguish this library from other similarly named projects or entities, such as the AI startup "Subquadratic" (which develops the "SubQ" model) [3][4], or other unrelated NVIDIA repositories like "optix-subd" [5] and "subquadratic-ops-torch" [6]. The nvSubquadratic library is a PyTorch-native toolkit developed by NVIDIA for subquadratic alternatives to traditional quadratic attention mechanisms, specifically utilizing optimized CUDA kernels for operators like Hyena [1][2].

Fix the nvSubquadratic link. https://github.com/NVIDIA/nvSubquadratic is dead; use https://github.com/NVIDIA-BioNeMo/nvSubquadratic instead.

Comment on lines +51 to +56
_LazyConfig, _has_nvsubq = optional_import("nvsubquadratic.lazy_config", name="LazyConfig")
_instantiate, _ = optional_import("nvsubquadratic.lazy_config", name="instantiate")
_Hyena, _ = optional_import("nvsubquadratic.modules.hyena_nd", name="Hyena")
_CKConvND, _ = optional_import("nvsubquadratic.modules.ckconv_nd", name="CKConvND")
_SIRENKernelND, _ = optional_import("nvsubquadratic.modules.kernels_nd", name="SIRENKernelND")
_GaussianModulationND, _ = optional_import("nvsubquadratic.modules.masks_nd", name="GaussianModulationND")

🩺 Stability & Availability | 🟡 Minor | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

ast-grep outline monai/networks/blocks/hyena.py --view expanded

Repository: Project-MONAI/MONAI

Length of output: 1102


🏁 Script executed:

sed -n '1,220p' monai/networks/blocks/hyena.py

Repository: Project-MONAI/MONAI

Length of output: 9030


🏁 Script executed:

sed -n '220,380p' monai/networks/blocks/hyena.py

Repository: Project-MONAI/MONAI

Length of output: 7391


🏁 Script executed:

rg -n "is_nvsubquadratic_available\\(" monai

Repository: Project-MONAI/MONAI

Length of output: 234


🏁 Script executed:

sed -n '380,560p' monai/networks/blocks/hyena.py

Repository: Project-MONAI/MONAI

Length of output: 6867


Gate availability on every optional nvsubquadratic import. is_nvsubquadratic_available() only tracks LazyConfig, so it can still return true when instantiate, Hyena, CKConvND, SIRENKernelND, or GaussianModulationND is missing; fold all of those import flags into the helper.

Comment on lines +95 to +135
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        spatial = x.shape[2:]
        kernel_shape = self.weight.shape[2:]  # type: ignore[attr-defined]
        fft_dims = tuple(range(-self._spatial_dims, 0))
        fft_size = [s + k - 1 for s, k in zip(spatial, kernel_shape)]
        in_dtype = x.dtype

        slices = (slice(None), slice(None)) + tuple(
            slice(k // 2, k // 2 + s) for s, k in zip(spatial, kernel_shape)
        )

        chunk = getattr(self, "fft_chunk_size", 0)
        if chunk > 0 and x.shape[1] > chunk:
            parts = []
            for c0 in range(0, x.shape[1], chunk):
                c1 = min(c0 + chunk, x.shape[1])
                xc = x[:, c0:c1].float()
                kc = self.weight[c0:c1].squeeze(1).float()  # type: ignore[attr-defined]
                kc = kc.flip(list(range(1, self._spatial_dims + 1)))
                xc_fft = torch.fft.rfftn(xc, s=fft_size, dim=fft_dims)
                kc_fft = torch.fft.rfftn(kc, s=fft_size, dim=fft_dims)
                out_fft = xc_fft * kc_fft.unsqueeze(0)
                del xc_fft, kc_fft
                out_c = torch.fft.irfftn(out_fft, s=fft_size, dim=fft_dims)
                del out_fft
                parts.append(out_c[slices].to(in_dtype))
                del out_c
            return torch.cat(parts, dim=1)

        x_f32 = x.float()
        k_f32 = self.weight.squeeze(1).float()  # type: ignore[attr-defined]
        # PyTorch ``F.conv*`` computes cross-correlation; FFT computes convolution.
        # Flip the kernel so the FFT output matches ``Conv{2,3}d`` exactly.
        k_f32 = k_f32.flip(list(range(1, self._spatial_dims + 1)))

        x_fft = torch.fft.rfftn(x_f32, s=fft_size, dim=fft_dims)
        k_fft = torch.fft.rfftn(k_f32, s=fft_size, dim=fft_dims)

        out_fft = x_fft * k_fft.unsqueeze(0)
        out = torch.fft.irfftn(out_fft, s=fft_size, dim=fft_dims)
        return out[slices].to(in_dtype)

🎯 Functional Correctness | 🟠 Major | ⚡ Quick win

Honor self.padding in the FFT crop.

DepthwiseFFTConv{2,3}d accepts padding, but forward() always crops as if padding == kernel_size // 2. Non-default padding, even kernels, or future callers expecting Conv semantics will return the wrong shape/values.

Possible fix direction
-        slices = (slice(None), slice(None)) + tuple(
-            slice(k // 2, k // 2 + s) for s, k in zip(spatial, kernel_shape)
-        )
+        padding = self.padding  # type: ignore[attr-defined]
+        if isinstance(padding, int):
+            padding = (padding,) * self._spatial_dims
+        out_spatial = [s + (2 * p) - k + 1 for s, p, k in zip(spatial, padding, kernel_shape, strict=True)]
+        slices = (
+            slice(None),
+            slice(None),
+            *(slice(k - 1 - p, k - 1 - p + out) for p, k, out in zip(padding, kernel_shape, out_spatial, strict=True)),
+        )
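
The crop offset in the fix direction above can be checked in one dimension without PyTorch. This is a pure-Python sketch of the index arithmetic only: the full linear convolution is computed directly rather than by FFT, stride is assumed to be 1, and all function names are illustrative.

```python
from typing import List


def cross_correlate_padded(x: List[float], w: List[float], p: int) -> List[float]:
    """Reference: cross-correlation with zero padding p and stride 1 (Conv1d semantics)."""
    xp = [0.0] * p + list(x) + [0.0] * p
    k = len(w)
    return [sum(xp[i + j] * w[j] for j in range(k)) for i in range(len(xp) - k + 1)]


def full_convolution(x: List[float], w: List[float]) -> List[float]:
    """Full linear convolution of length len(x) + len(w) - 1 (what the FFT product yields)."""
    out = [0.0] * (len(x) + len(w) - 1)
    for n, xv in enumerate(x):
        for m, wv in enumerate(w):
            out[n + m] += xv * wv
    return out


def crop_like_conv(x: List[float], w: List[float], p: int) -> List[float]:
    """Flip the kernel, convolve fully, then crop with the padding-aware offset."""
    s, k = len(x), len(w)
    if not 0 <= p <= k - 1:
        raise ValueError("this sketch only covers 0 <= padding <= kernel_size - 1")
    full = full_convolution(x, w[::-1])
    start = k - 1 - p            # offset proposed in the review
    length = s + 2 * p - k + 1   # Conv output size for stride 1
    return full[start:start + length]


def legacy_crop(x: List[float], w: List[float]) -> List[float]:
    """The hardcoded crop from the quoted forward(): start at k // 2, keep len(x) values."""
    k = len(w)
    return full_convolution(x, w[::-1])[k // 2:k // 2 + len(x)]


x = [1.0, 2.0, 3.0, 4.0, 5.0]
w_even = [1.0, 2.0, 3.0, 4.0]
print(crop_like_conv(x, w_even, p=2) == cross_correlate_padded(x, w_even, p=2))
print(len(legacy_crop(x, w_even)), len(cross_correlate_padded(x, w_even, p=2)))
```

For odd `k` with `p = k // 2`, the offset `k - 1 - p` equals `k // 2`, so the hardcoded crop coincides with the padding-aware one; the two diverge for even kernels and non-default padding.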
Suggested change
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        spatial = x.shape[2:]
        kernel_shape = self.weight.shape[2:]  # type: ignore[attr-defined]
        fft_dims = tuple(range(-self._spatial_dims, 0))
        fft_size = [s + k - 1 for s, k in zip(spatial, kernel_shape)]
        in_dtype = x.dtype
        padding = self.padding  # type: ignore[attr-defined]
        if isinstance(padding, int):
            padding = (padding,) * self._spatial_dims
        out_spatial = [s + (2 * p) - k + 1 for s, p, k in zip(spatial, padding, kernel_shape, strict=True)]
        slices = (
            slice(None),
            slice(None),
            *(slice(k - 1 - p, k - 1 - p + out) for p, k, out in zip(padding, kernel_shape, out_spatial, strict=True)),
        )
        chunk = getattr(self, "fft_chunk_size", 0)
        if chunk > 0 and x.shape[1] > chunk:
            parts = []
            for c0 in range(0, x.shape[1], chunk):
                c1 = min(c0 + chunk, x.shape[1])
                xc = x[:, c0:c1].float()
                kc = self.weight[c0:c1].squeeze(1).float()  # type: ignore[attr-defined]
                kc = kc.flip(list(range(1, self._spatial_dims + 1)))
                xc_fft = torch.fft.rfftn(xc, s=fft_size, dim=fft_dims)
                kc_fft = torch.fft.rfftn(kc, s=fft_size, dim=fft_dims)
                out_fft = xc_fft * kc_fft.unsqueeze(0)
                del xc_fft, kc_fft
                out_c = torch.fft.irfftn(out_fft, s=fft_size, dim=fft_dims)
                del out_fft
                parts.append(out_c[slices].to(in_dtype))
                del out_c
            return torch.cat(parts, dim=1)
        x_f32 = x.float()
        k_f32 = self.weight.squeeze(1).float()  # type: ignore[attr-defined]
        # PyTorch ``F.conv*`` computes cross-correlation; FFT computes convolution.
        # Flip the kernel so the FFT output matches ``Conv{2,3}d`` exactly.
        k_f32 = k_f32.flip(list(range(1, self._spatial_dims + 1)))
        x_fft = torch.fft.rfftn(x_f32, s=fft_size, dim=fft_dims)
        k_fft = torch.fft.rfftn(k_f32, s=fft_size, dim=fft_dims)
        out_fft = x_fft * k_fft.unsqueeze(0)
        out = torch.fft.irfftn(out_fft, s=fft_size, dim=fft_dims)
        return out[slices].to(in_dtype)
🧰 Tools
🪛 Ruff (0.15.18)

[warning] 99-99: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)


[warning] 102-104: Consider iterable unpacking instead of concatenation

Replace with iterable unpacking

(RUF005)


[warning] 103-103: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/networks/blocks/hyena.py` around lines 95 - 135, the FFT conv
`forward()` currently hardcodes the crop around `k // 2` for each kernel
dimension, so `DepthwiseFFTConv{2,3}d` ignores its `self.padding` setting and
can produce incorrect outputs for non-default padding or even-sized kernels.
Update the crop logic in `forward()` (including the chunked path) to use the
module's padding value when building `slices`, and keep the FFT result aligned
with Conv{2,3}d semantics for both output shape and values.
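
The crop rule used in the suggested `forward()` (flip the kernel, take the full linear convolution of length `s + k - 1`, then slice from `k - 1 - p` for `s + 2p - k + 1` elements) can be checked without FFTs. The following is a minimal 1D, pure-Python sketch; the function names are illustrative and not part of the PR, and it assumes `0 <= p <= k - 1`:

```python
def cross_correlate(x, w, padding):
    """Direct 1D cross-correlation with zero padding (the Conv1d convention)."""
    k = len(w)
    padded = [0] * padding + list(x) + [0] * padding
    return [sum(padded[i + j] * w[j] for j in range(k)) for i in range(len(padded) - k + 1)]


def full_convolution(x, w):
    """Full linear convolution of length len(x) + len(w) - 1.

    This is the sequence that a zero-padded FFT product of size s + k - 1 reproduces.
    """
    out = [0] * (len(x) + len(w) - 1)
    for m, xv in enumerate(x):
        for j, wv in enumerate(w):
            out[m + j] += xv * wv
    return out


def cropped_flipped_convolution(x, w, padding):
    """Flip the kernel, convolve fully, then crop starting at k - 1 - p."""
    k = len(w)
    full = full_convolution(x, w[::-1])
    out_len = len(x) + 2 * padding - k + 1
    start = k - 1 - padding  # assumption: 0 <= padding <= k - 1
    return full[start : start + out_len]


x = [1, 2, 3, 4, 5, 6]
for w in ([1, 2, 3], [1, 0, -1, 2]):  # odd and even kernel sizes
    for padding in range(len(w)):  # every padding from 0 to k - 1
        assert cropped_flipped_convolution(x, w, padding) == cross_correlate(x, w, padding)
```

Integer inputs keep the comparison exact; the real module works in float32 through `rfftn`/`irfftn`, so an equivalent check against `Conv{2,3}d` would need a tolerance.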

Comment on lines +363 to +367
warnings.warn(
    f"Skipping {layer_name} block weights: stage uses HyenaTransformerBlock, "
    "which has no compatible Swin attention weights. Blocks remain at their "
    "random initialization."
)


📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win

Set stacklevel on the warning.

Without it, callers see the warning location inside load_from instead of their call site.
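
The effect of `stacklevel` can be observed with the standard library alone. This is a small sketch with stand-in helper names (they are not the PR's `load_from`); it records a warning and compares the line it is attributed to with the line of the call site:

```python
import inspect
import warnings


def load_from_without_stacklevel():
    # Default stacklevel=1: the warning is attributed to this line.
    warnings.warn("Skipping block weights")


def load_from_with_stacklevel():
    # stacklevel=2: the warning is attributed to whoever called this function.
    warnings.warn("Skipping block weights", stacklevel=2)


def reported_vs_call_line(helper):
    """Return (line the warning was attributed to, line where helper was called)."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        call_line = inspect.currentframe().f_lineno + 1
        helper()
    return caught[0].lineno, call_line


reported, call_line = reported_vs_call_line(load_from_with_stacklevel)
print("with stacklevel=2, points at call site:", reported == call_line)

reported, call_line = reported_vs_call_line(load_from_without_stacklevel)
print("without stacklevel, points at call site:", reported == call_line)
```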

Proposed fix
                     warnings.warn(
                         f"Skipping {layer_name} block weights: stage uses HyenaTransformerBlock, "
                         "which has no compatible Swin attention weights. Blocks remain at their "
-                        "random initialization."
+                        "random initialization.",
+                        stacklevel=2,
                     )
🧰 Tools
🪛 Ruff (0.15.18)

[warning] 363-363: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/networks/nets/swin_unetr.py` around lines 363 - 367, The warning
emitted in `load_from` for `HyenaTransformerBlock` should set an explicit
stacklevel so the caller sees the warning at their call site instead of inside
this helper. Update the existing warnings.warn call in `swin_unetr.py` to
include a stacklevel argument, keeping the same message and using the
`load_from` path where `layer_name` and `HyenaTransformerBlock` are handled.

Source: Linters/SAST tools

Comment on lines +956 to +1001
self.use_hyena = use_hyena
if use_hyena:
    self.blocks = nn.ModuleList(
        [
            HyenaTransformerBlock(
                dim=dim,
                spatial_dims=len(self.window_size),
                mlp_ratio=mlp_ratio,
                drop=drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer,
                use_checkpoint=use_checkpoint,
                hyena_kernel_size=hyena_kernel_size,
                hyena_kernel_mlp_dim=hyena_kernel_mlp_dim,
                hyena_kernel_layers=hyena_kernel_layers,
                hyena_mask_max_attenuation=hyena_mask_max_attenuation,
                hyena_fft_padding=hyena_fft_padding,
                hyena_grid_type=hyena_grid_type,
                hyena_use_chunked_fft=hyena_use_chunked_fft,
                hyena_use_fft_short_conv=hyena_use_fft_short_conv,
                hyena_omega_0=hyena_omega_0,
                hyena_l_cache=hyena_l_cache,
                hyena_short_conv_fft_chunks=hyena_short_conv_fft_chunks,
            )
            for i in range(depth)
        ]
    )
else:
    self.blocks = nn.ModuleList(
        [
            SwinTransformerBlock(
                dim=dim,
                num_heads=num_heads,
                window_size=self.window_size,
                shift_size=self.no_shift if (i % 2 == 0) else self.shift_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer,
                use_checkpoint=use_checkpoint,
            )
            for i in range(depth)
        ]
    )


🚀 Performance & Scalability | 🟠 Major | ⚡ Quick win

Skip attention-mask construction for Hyena stages.

self.use_hyena is set here, but forward() still calls compute_mask(...) on every pass. HyenaTransformerBlock.forward ignores mask_matrix, so Hyena stages pay attention-mask memory and latency for a tensor they never use.

Proposed fix
         if len(x_shape) == 5:
             b, c, d, h, w = x_shape
-            window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size)
             x = rearrange(x, "b c d h w -> b d h w c")
-            dp = int(np.ceil(d / window_size[0])) * window_size[0]
-            hp = int(np.ceil(h / window_size[1])) * window_size[1]
-            wp = int(np.ceil(w / window_size[2])) * window_size[2]
-            attn_mask = compute_mask([dp, hp, wp], window_size, shift_size, x.device)
+            attn_mask = None
+            if not self.use_hyena:
+                window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size)
+                dp = int(np.ceil(d / window_size[0])) * window_size[0]
+                hp = int(np.ceil(h / window_size[1])) * window_size[1]
+                wp = int(np.ceil(w / window_size[2])) * window_size[2]
+                attn_mask = compute_mask([dp, hp, wp], window_size, shift_size, x.device)
             for blk in self.blocks:
                 x = blk(x, attn_mask)
@@
         elif len(x_shape) == 4:
             b, c, h, w = x_shape
-            window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size)
             x = rearrange(x, "b c h w -> b h w c")
-            hp = int(np.ceil(h / window_size[0])) * window_size[0]
-            wp = int(np.ceil(w / window_size[1])) * window_size[1]
-            attn_mask = compute_mask([hp, wp], window_size, shift_size, x.device)
+            attn_mask = None
+            if not self.use_hyena:
+                window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size)
+                hp = int(np.ceil(h / window_size[0])) * window_size[0]
+                wp = int(np.ceil(w / window_size[1])) * window_size[1]
+                attn_mask = compute_mask([hp, wp], window_size, shift_size, x.device)
             for blk in self.blocks:
                 x = blk(x, attn_mask)
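
One way to guard the change above is a test asserting that the mask builder is never invoked for a Hyena stage. The sketch below is framework-free and purely illustrative: `StageSketch` and its `mask_fn` parameter are hypothetical stand-ins for the stage module and `compute_mask`, not MONAI APIs.

```python
from unittest import mock


class StageSketch:
    """Hypothetical stand-in for the stage module; only models the mask decision."""

    def __init__(self, use_hyena, mask_fn, window_size=(7, 7, 7), shift_size=(3, 3, 3)):
        self.use_hyena = use_hyena
        self.mask_fn = mask_fn  # stand-in for compute_mask
        self.window_size = window_size
        self.shift_size = shift_size

    def forward(self, spatial_shape):
        attn_mask = None
        if not self.use_hyena:
            # Round each spatial size up to a multiple of the window size.
            padded = [-(-s // w) * w for s, w in zip(spatial_shape, self.window_size)]
            attn_mask = self.mask_fn(padded, self.window_size, self.shift_size)
        return attn_mask


# Hyena stage: the mask builder must not run and no mask is produced.
hyena_mask_fn = mock.Mock(return_value="mask")
assert StageSketch(True, hyena_mask_fn).forward((64, 64, 64)) is None
hyena_mask_fn.assert_not_called()

# Attention stage: the mask builder runs once with the padded shape.
swin_mask_fn = mock.Mock(return_value="mask")
assert StageSketch(False, swin_mask_fn).forward((64, 64, 64)) == "mask"
swin_mask_fn.assert_called_once_with([70, 70, 70], (7, 7, 7), (3, 3, 3))
```

In the real module the same idea would likely be expressed by patching `compute_mask` where the stage looks it up, which is an assumption about the module layout rather than something shown in this thread.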
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/networks/nets/swin_unetr.py` around lines 956 - 1001, The Hyena path
still pays the cost of attention-mask construction even though
HyenaTransformerBlock.forward ignores mask_matrix. Update the stage forward
logic that uses self.use_hyena and compute_mask(...) so mask generation is
skipped entirely for Hyena stages, while preserving the existing
SwinTransformerBlock path for non-Hyena stages. Keep the change localized to the
stage module that initializes self.blocks and dispatches to the block forward
calls.

Comment on lines +76 to +84
def test_duplicate_hyena_stages_kwarg_rejected(self):
    with self.assertRaisesRegex(TypeError, "hyena_stages"):
        HyenaNDUNETR(
            in_channels=1,
            out_channels=14,
            feature_size=12,
            hyena_stages=(True, True, False, False),
            **{"hyena_stages": (True, False, True, False)},  # type: ignore[arg-type]
        )


🎯 Functional Correctness | 🟡 Minor | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

printf '\n== test file slice ==\n'
sed -n '1,180p' tests/networks/nets/test_hyena_nd_unetr.py

printf '\n== implementation slice ==\n'
sed -n '1,240p' monai/networks/nets/hyena_nd_unetr.py

printf '\n== search for related constructor/tests ==\n'
rg -n "hyena_stages|HAS_NVSUBQ|ImportError|TypeError" tests/networks/nets/test_hyena_nd_unetr.py monai/networks/nets/hyena_nd_unetr.py

Repository: Project-MONAI/MONAI

Length of output: 20034


Remove the unreachable hyena_stages kwarg check. Passing hyena_stages both explicitly and via **{...} raises TypeError while Python assembles the call's keyword arguments, before HyenaNDUNETR.__init__ runs, so this test only exercises call-site argument handling and the if "hyena_stages" in kwargs: branch can never fire.
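
The behavior is easy to reproduce in isolation. In this sketch `make_variant` is a hypothetical stand-in for a constructor that rejects `hyena_stages` in `**kwargs`; it is not the PR's implementation. The duplicate-keyword call fails before the function body runs, while passing the key only through `**kwargs` reaches the check:

```python
entered = []  # records every time the function body actually starts running


def make_variant(pattern, **kwargs):
    """Hypothetical constructor that derives hyena_stages from a pattern string."""
    entered.append(pattern)
    if "hyena_stages" in kwargs:
        raise TypeError("hyena_stages is derived from the pattern; do not pass it")
    return pattern


def call_with_duplicate():
    """Explicit keyword plus the same key in ** unpacking: fails at the call site."""
    try:
        make_variant(
            "HHAA",
            hyena_stages=(True, True, False, False),
            **{"hyena_stages": (True, False, True, False)},
        )
    except TypeError as exc:
        return str(exc)
    return None


def call_via_kwargs_only():
    """Key supplied only through ** unpacking: the in-body check is reachable."""
    try:
        make_variant("HHAA", **{"hyena_stages": (True, False, True, False)})
    except TypeError as exc:
        return str(exc)
    return None


print(call_with_duplicate())   # message produced by the interpreter
print(entered)                 # still empty: the body never ran
print(call_via_kwargs_only())  # message produced by make_variant itself
print(entered)                 # the body ran once
```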

🧰 Tools
🪛 Ruff (0.15.18)

[error] 83-83: Repeated keyword argument: hyena_stages

(PLE1132)

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/networks/nets/test_hyena_nd_unetr.py` around lines 76 - 84, The
duplicate hyena_stages kwarg assertion in
test_duplicate_hyena_stages_kwarg_rejected is unreachable because Python raises
TypeError at the call site before HyenaNDUNETR.__init__ can inspect kwargs.
Remove this test or rewrite it to cover a reachable duplicate-argument path, and
if keeping a kwargs validation check in HyenaNDUNETR.__init__, add a separate
test that passes hyena_stages only through kwargs so the branch can actually
execute.

Source: Linters/SAST tools

Comment on lines +170 to +187
@skip_if_no_cuda
def test_default_path_unchanged(self):
    """SwinUNETR with no hyena kwargs produces the pre-port golden output."""
    import hashlib

    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    net = SwinUNETR(in_channels=1, out_channels=14, feature_size=12).cuda().eval()
    torch.manual_seed(1)
    x = torch.randn(1, 1, 64, 64, 64, device="cuda")
    with torch.no_grad():
        out = net(x)
    h = hashlib.sha256(out.flatten()[:64].cpu().numpy().tobytes()).hexdigest()
    self.assertEqual(
        h,
        HYENA_BACKCOMPAT_GOLDEN_HASH,
        "Default SwinUNETR forward output drifted; HyenaND port changed the non-Hyena path.",
    )


🩺 Stability & Availability | 🟠 Major | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

ast-grep outline tests/networks/nets/test_swin_unetr.py --view expanded

Repository: Project-MONAI/MONAI

Length of output: 1974


🏁 Script executed:

sed -n '1,260p' tests/networks/nets/test_swin_unetr.py

Repository: Project-MONAI/MONAI

Length of output: 10596


🏁 Script executed:

python3 - <<'PY'
from pathlib import Path
p = Path('tests/networks/nets/test_swin_unetr.py')
text = p.read_text()
for name in ['HYENA_BACKCOMPAT_GOLDEN_HASH', 'skip_if_no_cuda']:
    print(f'## {name}')
    for i, line in enumerate(text.splitlines(), 1):
        if name in line:
            print(i, line)
PY

Repository: Project-MONAI/MONAI

Length of output: 505


🏁 Script executed:

git diff --stat -- tests/networks/nets/test_swin_unetr.py

Repository: Project-MONAI/MONAI

Length of output: 157


🏁 Script executed:

rg -n "sha256|golden|bit-identical|assert_close|flatten\(\)\[:64\]" tests monai -g '*.py'

Repository: Project-MONAI/MONAI

Length of output: 5305


Avoid hashing raw CUDA bytes here. This golden check will fail on harmless GPU/CUDA/PyTorch numeric drift even if the non-Hyena path is still correct. Compare against a stored reference tensor with torch.testing.assert_close(...) or another tolerance-based invariant instead.
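
To illustrate why a byte hash is brittle, the following standard-library sketch perturbs a few reference values by a relative amount of about 1e-6, which is an assumed stand-in for the kind of drift seen across GPU, CUDA, or PyTorch versions. The SHA-256 digest of the float32 bytes changes, while a tolerance check still passes. `all_close` here is a simplified helper built on `math.isclose`, not `torch.testing.assert_close`, and the tolerances are illustrative.

```python
import hashlib
import math
import struct


def digest(values):
    """SHA-256 over the little-endian float32 bytes of the values."""
    return hashlib.sha256(struct.pack(f"<{len(values)}f", *values)).hexdigest()


def all_close(actual, expected, rtol=1e-4, atol=1e-5):
    """Element-wise tolerance check; lengths are compared before zipping."""
    return len(actual) == len(expected) and all(
        math.isclose(a, e, rel_tol=rtol, abs_tol=atol) for a, e in zip(actual, expected)
    )


reference = [0.125, -0.5, 0.75, 1.5]
# Simulated numeric drift: a relative change of 2**-20 per element.
drifted = [v * (1 + 2**-20) for v in reference]

print("hashes equal:", digest(reference) == digest(drifted))
print("within tolerance:", all_close(drifted, reference))
```

In the actual test the reference would be a stored tensor compared with `torch.testing.assert_close`, with tolerances chosen for the model's dtype.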

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/networks/nets/test_swin_unetr.py` around lines 170 - 187, The golden
check in test_default_path_unchanged is too strict because it hashes raw CUDA
bytes, so replace the SHA256 comparison with a tolerance-based validation such
as torch.testing.assert_close against a stored reference tensor or another
numeric invariant. Keep the test focused on the SwinUNETR default forward path
and preserve the existing setup in test_default_path_unchanged, but avoid
byte-level equality that can fail from harmless GPU/PyTorch drift.
