diff --git a/.github/workflows/cicd_tests.yml b/.github/workflows/cicd_tests.yml index ae3694f276..10da9db132 100644 --- a/.github/workflows/cicd_tests.yml +++ b/.github/workflows/cicd_tests.yml @@ -237,6 +237,55 @@ jobs: python -m unittest -v shell: bash + hyena-dep: # Optional HyenaND dependency + the no-CUDA Hyena tests. + # nvsubquadratic >= 0.1.1 supports Python >= 3.10 and keeps its CUDA-kernel sdist + # (subquadratic-ops-torch-cu12) plus the megatron / dali / timm packages in opt-in + # extras, so it installs on a CPU runner. We still pass ``--no-deps`` deliberately: + # (1) the HyenaND operators import only torch + einops + omegaconf at runtime, so + # skipping the (still batteries-included: datasets/lightning/wandb) core deps + # keeps this job lean; and + # (2) nvsubquadratic pins torch>=2.10,<2.11, which would otherwise upgrade/clash + # with the torch this job (and MONAI's matrix) installs. + # CUDA-required Hyena tests skip cleanly here; the GPU surface is covered by + # ``.github/workflows/pythonapp-hyena-gpu.yml``. + runs-on: ubuntu-latest + timeout-minutes: 30 + steps: + - name: Clean unused tools + run: | + find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \; + sudo rm -rf /usr/share/dotnet + sudo rm -rf /usr/local/lib/android + sudo rm -rf /opt/ghc /usr/local/.ghcup + sudo docker system prune -f + - uses: actions/checkout@v6 + - name: Set up Python ${{ env.PYTHON_VER1 }} + uses: actions/setup-python@v6 + with: + python-version: ${{ env.PYTHON_VER1 }} + cache: 'pip' + - name: Install dependencies + nvsubquadratic (no-deps) + run: | + python -m pip install --upgrade pip wheel + python -m pip install torch==${PYTORCH_VER1} torchvision==${TORCHVISION_VER1} + python -m pip install --no-build-isolation -r requirements-dev.txt + python -m pip install -e . + # nvsubquadratic runtime imports need only torch + einops + omegaconf; install + # the package itself without its core dependency tree (see job comment above). 
+ python -m pip install omegaconf + python -m pip install --no-deps 'nvsubquadratic>=0.1.1' + python -m pip list + shell: bash + - name: Run Hyena tests (CUDA-required cases skip cleanly) + run: | + python -c "from monai.networks.blocks.hyena import is_nvsubquadratic_available; \ + assert is_nvsubquadratic_available(), 'nvsubquadratic must be importable'" + python -m pytest -v \ + tests/networks/blocks/test_hyena_block.py \ + tests/networks/nets/test_hyena_nd_unetr.py \ + tests/networks/nets/test_swin_unetr.py + shell: bash + packaging: # Test package generation runs-on: ubuntu-latest env: diff --git a/.github/workflows/pythonapp-hyena-gpu.yml b/.github/workflows/pythonapp-hyena-gpu.yml new file mode 100644 index 0000000000..af49cb292b --- /dev/null +++ b/.github/workflows/pythonapp-hyena-gpu.yml @@ -0,0 +1,74 @@ +# Optional self-hosted GPU CI for the HyenaND test surface. +# +# This workflow exercises the CUDA-required Hyena tests +# (tests/networks/blocks/test_hyena_block.py CUDA cases, the four-paper-variant +# forward and gradient cases in tests/networks/nets/test_swin_unetr.py and +# tests/networks/nets/test_hyena_nd_unetr.py, the SwinUNETR(use_hyena=False) +# golden-hash backward-compat regression, and sliding-window inference). +# +# Disabled by default (``if: false``). To enable: +# 1. Ensure a self-hosted runner with the labels below is available, AND +# 2. Ensure the runner has CUDA-capable hardware visible (the existing +# ``pythonapp-gpu.yml`` uses ``--gpus all`` against ``[self-hosted, linux, +# x64, common]``). Reuse that pool if possible. +# 3. Flip ``if: false`` to ``if: github.event.pull_request.merged != true`` +# (mirroring ``pythonapp-gpu.yml``'s gating pattern). +# +# nvsubquadratic (Hyena's optional dep) requires Python >= 3.10; any NGC base with +# Python >= 3.10 works. The accelerated [cuda] kernels build against the container nvcc. 
+ +name: hyena-gpu + +on: + workflow_dispatch: + +concurrency: + group: hyena-gpu-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + GPU-Hyena: + if: ${{ false }} # See header for enable instructions. + strategy: + matrix: + environment: + # NGC PyTorch 25.05 ships Python 3.12 and CUDA 12.5. Bump as needed. + - "NGC25.05+PY312" + include: + - environment: NGC25.05+PY312 + base: "nvcr.io/nvidia/pytorch:25.05-py3" + container: + image: ${{ matrix.base }} + options: --gpus all --env NVIDIA_DISABLE_REQUIRE=true + runs-on: [self-hosted, linux, x64, common] + steps: + - uses: actions/checkout@v6 + - name: Install dependencies + run: | + python -m pip install --upgrade pip wheel + python -c "import sys; assert sys.version_info >= (3, 10), f'Python >= 3.10 required for nvsubquadratic, got {sys.version}'" + python -m pip install -r requirements-dev.txt + python -m pip install -e . + # Install nvsubquadratic with --no-deps: the default torch_fft path needs only + # torch + einops + omegaconf, and nvsubquadratic pins torch>=2.10,<2.11 which can + # clash with the container's torch. To exercise the accelerated fused CUDA + # kernels instead, install the [cuda] extra (subquadratic-ops-torch-cu12, builds + # against the container's nvcc) and set fft_backend="subq_ops" in the tests. 
+ python -m pip install omegaconf + python -m pip install --no-deps 'nvsubquadratic>=0.1.1' + python -m pip list + shell: bash + - name: Verify CUDA + nvsubquadratic + run: | + nvidia-smi + python -c "import torch; assert torch.cuda.is_available(); print('CUDA OK:', torch.cuda.get_device_name(0))" + python -c "from monai.networks.blocks.hyena import is_nvsubquadratic_available; \ + assert is_nvsubquadratic_available(), 'nvsubquadratic must be importable'" + shell: bash + - name: Run Hyena test suite (CUDA + no-CUDA) + run: | + python -m pytest -v \ + tests/networks/blocks/test_hyena_block.py \ + tests/networks/nets/test_hyena_nd_unetr.py \ + tests/networks/nets/test_swin_unetr.py + shell: bash diff --git a/CHANGELOG.md b/CHANGELOG.md index 419210a903..c8731ddb42 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,11 @@ All notable changes to MONAI are documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [Unreleased] +### Added +* `HyenaMixer`, `HyenaTransformerBlock`, and `DepthwiseFFTConv{2,3}d` in `monai.networks.blocks`: subquadratic O(N log N) alternatives to windowed self-attention, backed by the HyenaND operator from the optional `nvsubquadratic` package. +* `HyenaNDUNETR` (`monai.networks.nets.HyenaNDUNETR`): thin `SwinUNETR` subclass with a `get_variant(name)` classmethod for the three Hyena variants (`HHHH`, `HAHA`, `HHAA`) from the NeurIPS 2026 paper "Native Multi-Dimensional Subquadratic Operators via Input Dependent Long Convolutions" (paper id 26539). +* `SwinUNETR.use_hyena` and `SwinUNETR.hyena_stages` kwargs to thread HyenaND blocks through any subset of Swin stages. Default `use_hyena=False` preserves bit-identical forward behavior of the existing code path. +* New `[hyena]` extras_require in setup.cfg (`pip install monai[hyena]`). 
## [1.6.0] - 2026-06-12 diff --git a/docs/source/installation.md b/docs/source/installation.md index 5123bc3e6b..df5705ed05 100644 --- a/docs/source/installation.md +++ b/docs/source/installation.md @@ -254,10 +254,15 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is - The options are ``` -[nibabel, skimage, scipy, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr, lpips, pynvml, huggingface_hub] +[nibabel, skimage, scipy, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr, lpips, pynvml, huggingface_hub, pyamg, hyena] ``` which correspond to `nibabel`, `scikit-image`,`scipy`, `pillow`, `tensorboard`, -`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, `zarr`, `lpips`, `nvidia-ml-py`, `huggingface_hub` and `pyamg` respectively. +`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, `zarr`, `lpips`, `nvidia-ml-py`, `huggingface_hub`, `pyamg`, and `nvsubquadratic` respectively. 
+ +The `hyena` extra pulls in [`nvsubquadratic`](https://github.com/NVIDIA-BioNeMo/nvSubquadratic), +required by `HyenaNDUNETR` / `HyenaMixer` / `HyenaTransformerBlock` (subquadratic +O(N log N) alternatives to windowed self-attention). Install with +`pip install 'monai[hyena]'`. - `pip install 'monai[all]'` installs all the optional dependencies. diff --git a/docs/source/networks.rst b/docs/source/networks.rst index de0aece3f7..e7709678b7 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -129,6 +129,23 @@ Blocks .. autoclass:: TransformerBlock :members: +`Hyena Mixer` +~~~~~~~~~~~~~ +.. autoclass:: HyenaMixer + :members: + +`Hyena Transformer Block` +~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: HyenaTransformerBlock + :members: + +`Depthwise FFT Convolution` +~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: DepthwiseFFTConv2d + :members: +.. autoclass:: DepthwiseFFTConv3d + :members: + `UNETR Block` ~~~~~~~~~~~~~ .. autoclass:: UnetrBasicBlock @@ -591,6 +608,11 @@ Nets .. autoclass:: SwinUNETR :members: +`HyenaNDUNETR` +~~~~~~~~~~~~~~ +.. autoclass:: HyenaNDUNETR + :members: + `BasicUNet` ~~~~~~~~~~~ .. 
autoclass:: BasicUNet diff --git a/monai/networks/blocks/__init__.py b/monai/networks/blocks/__init__.py index 22af82d316..5932aba7fe 100644 --- a/monai/networks/blocks/__init__.py +++ b/monai/networks/blocks/__init__.py @@ -26,6 +26,13 @@ from .encoder import BaseEncoder from .fcn import FCN, GCN, MCFCN, Refine from .feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool, LastLevelP6P7 +from .hyena import ( + DepthwiseFFTConv2d, + DepthwiseFFTConv3d, + HyenaMixer, + HyenaTransformerBlock, + is_nvsubquadratic_available, +) from .localnet_block import LocalNetDownSampleBlock, LocalNetFeatureExtractorBlock, LocalNetUpSampleBlock from .mednext_block import MedNeXtBlock, MedNeXtDownBlock, MedNeXtOutBlock, MedNeXtUpBlock from .mlp import MLPBlock diff --git a/monai/networks/blocks/hyena.py b/monai/networks/blocks/hyena.py new file mode 100644 index 0000000000..5ae0164aa3 --- /dev/null +++ b/monai/networks/blocks/hyena.py @@ -0,0 +1,551 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +HyenaND-based building blocks for MONAI networks. + +These blocks provide a subquadratic O(N log N) alternative to windowed self-attention +in transformer-style segmentation networks. The operator is HyenaND from the +``nvsubquadratic`` package, gated through a thin :class:`HyenaMixer` and wrapped in +the conventional pre-norm / MLP residual pattern by :class:`HyenaTransformerBlock`. 
+ +The supporting :class:`DepthwiseFFTConv2d` / :class:`DepthwiseFFTConv3d` classes are +drop-in depthwise convolutions implemented via FFT. They preserve the +``nn.Conv{2,3}d`` weight layout and ``isinstance`` relationship but route the forward +pass through ``torch.fft.rfftn`` to avoid PyTorch's INT32 unfold limit, which caps +``F.conv3d`` at ROI ~128 for typical medical-imaging channel counts. + +``nvsubquadratic`` is an optional dependency. The Hyena classes raise ``ImportError`` +with an install hint at construction time if the library is unavailable; the +FFT-conv classes have no such dependency and always work. +""" + +from __future__ import annotations + +import torch +import torch.nn as nn +from torch.nn import LayerNorm + +from monai.networks.blocks.mlp import MLPBlock +from monai.networks.layers.drop_path import DropPath +from monai.utils import optional_import + +__all__ = [ + "DepthwiseFFTConv2d", + "DepthwiseFFTConv3d", + "HyenaMixer", + "HyenaTransformerBlock", + "is_nvsubquadratic_available", +] + +# Optional ``nvsubquadratic`` symbols. Resolved at module-import time; the missing-dep +# error is raised lazily inside the consuming class ``__init__``. Every symbol the Hyena +# classes use carries its own availability flag so a partial / broken install (e.g. +# ``lazy_config`` present but ``modules.hyena_nd`` missing) reports unavailable rather than +# failing later with an opaque ``AttributeError``. 
+_LazyConfig, _has_lazyconfig = optional_import("nvsubquadratic.lazy_config", name="LazyConfig") +_instantiate, _has_instantiate = optional_import("nvsubquadratic.lazy_config", name="instantiate") +_Hyena, _has_hyena = optional_import("nvsubquadratic.modules.hyena_nd", name="Hyena") +_CKConvND, _has_ckconv = optional_import("nvsubquadratic.modules.ckconv_nd", name="CKConvND") +_SIRENKernelND, _has_siren = optional_import("nvsubquadratic.modules.kernels_nd", name="SIRENKernelND") +_GaussianModulationND, _has_gaussian = optional_import("nvsubquadratic.modules.masks_nd", name="GaussianModulationND") +_has_nvsubq = all((_has_lazyconfig, _has_instantiate, _has_hyena, _has_ckconv, _has_siren, _has_gaussian)) + +_NVSUBQ_INSTALL_HINT = ( + "Hyena operators require the optional ``nvsubquadratic`` package. " + "Install with ``pip install nvsubquadratic`` " + "or ``pip install monai[hyena]``." +) + + +def is_nvsubquadratic_available() -> bool: + """Return ``True`` if the optional ``nvsubquadratic`` package is importable.""" + return bool(_has_nvsubq) + + +# --------------------------------------------------------------------------- +# Depthwise FFT convolutions — no nvsubquadratic dependency +# --------------------------------------------------------------------------- + + +class _DepthwiseFFTForward: + """Mixin providing FFT-based forward for depthwise ``nn.Conv{2,3}d`` subclasses. + + No ``nn.Module`` parent: module machinery comes from ``nn.Conv2d`` / ``nn.Conv3d`` + in the concrete subclasses. Placed first in the MRO so ``forward`` resolves here + (FFT) rather than to ``nn.Conv{2,3}d.forward`` (im2col / unfold). + + Avoids PyTorch's im2col INT32 overflow, which caps ``F.conv3d`` at ROI ~128 for + typical medical-imaging channel counts. There is no spatial-size restriction. + + ``fft_chunk_size > 0`` enables channel-chunked FFT to cap peak memory: + + peak ≈ (B × chunk × spatial × 4 + B × chunk × rfft_spatial × 8) bytes + + instead of the full ``(B × C × ...)`` allocation. 
+ """ + + _spatial_dims: int # set by subclasses + fft_chunk_size: int = 0 # 0 = no chunking; set in subclass __init__ + + def forward(self, x: torch.Tensor) -> torch.Tensor: + spatial = x.shape[2:] + kernel_shape = self.weight.shape[2:] # type: ignore[attr-defined] + fft_dims = tuple(range(-self._spatial_dims, 0)) + fft_size = [s + k - 1 for s, k in zip(spatial, kernel_shape)] + in_dtype = x.dtype + + slices = (slice(None), slice(None)) + tuple(slice(k // 2, k // 2 + s) for s, k in zip(spatial, kernel_shape)) + + chunk = getattr(self, "fft_chunk_size", 0) + if chunk > 0 and x.shape[1] > chunk: + parts = [] + for c0 in range(0, x.shape[1], chunk): + c1 = min(c0 + chunk, x.shape[1]) + xc = x[:, c0:c1].float() + kc = self.weight[c0:c1].squeeze(1).float() # type: ignore[attr-defined] + kc = kc.flip(list(range(1, self._spatial_dims + 1))) + xc_fft = torch.fft.rfftn(xc, s=fft_size, dim=fft_dims) + kc_fft = torch.fft.rfftn(kc, s=fft_size, dim=fft_dims) + out_fft = xc_fft * kc_fft.unsqueeze(0) + del xc_fft, kc_fft + out_c = torch.fft.irfftn(out_fft, s=fft_size, dim=fft_dims) + del out_fft + parts.append(out_c[slices].to(in_dtype)) + del out_c + return torch.cat(parts, dim=1) + + x_f32 = x.float() + k_f32 = self.weight.squeeze(1).float() # type: ignore[attr-defined] + # PyTorch ``F.conv*`` computes cross-correlation; FFT computes convolution. + # Flip the kernel so the FFT output matches ``Conv{2,3}d`` exactly. + k_f32 = k_f32.flip(list(range(1, self._spatial_dims + 1))) + + x_fft = torch.fft.rfftn(x_f32, s=fft_size, dim=fft_dims) + k_fft = torch.fft.rfftn(k_f32, s=fft_size, dim=fft_dims) + + out_fft = x_fft * k_fft.unsqueeze(0) + out = torch.fft.irfftn(out_fft, s=fft_size, dim=fft_dims) + return out[slices].to(in_dtype) + + +def _validate_depthwise_fft_args( + in_channels: int, out_channels: int, kernel_size: int, groups: int, padding: int, bias: bool +) -> None: + """Validate the constructor arguments shared by ``DepthwiseFFTConv{2,3}d``. 
+ + The FFT forward only implements depthwise, bias-free, ``"same"``-style convolution: it + crops the full convolution back to the input spatial size assuming ``padding == + kernel_size // 2`` with an odd kernel. Reject anything else up front rather than return a + silently wrong shape. + """ + if not (in_channels == out_channels == groups): + raise ValueError( + "DepthwiseFFTConv only supports depthwise (groups == in_channels == out_channels); " + f"got in_channels={in_channels}, out_channels={out_channels}, groups={groups}" + ) + if bias: + raise ValueError("bias is not supported in DepthwiseFFTConv") + if kernel_size % 2 == 0 or padding != kernel_size // 2: + raise ValueError( + "DepthwiseFFTConv only supports 'same'-style padding: kernel_size must be odd and " + f"padding must equal kernel_size // 2; got kernel_size={kernel_size}, padding={padding}. " + "The FFT forward crops to the input spatial size and does not implement general padding." + ) + + +class DepthwiseFFTConv2d(_DepthwiseFFTForward, nn.Conv2d): + """2-D depthwise FFT convolution. ``isinstance(x, nn.Conv2d)`` remains ``True``. + + Drop-in replacement for an ``nn.Conv2d`` with ``groups == in_channels == out_channels`` + and ``bias=False``. Useful as the short-conv inside :class:`HyenaMixer` at large 2-D + inputs, where ``F.conv2d`` would not yet hit the INT32 limit but the unified + Conv2d/Conv3d API is convenient. + + Only ``"same"``-style padding is supported: ``padding`` must equal ``kernel_size // 2`` + (and ``kernel_size`` must be odd). The FFT forward crops its output back to the input + spatial size and does not implement general (e.g. ``"valid"``) padding. 
+ """ + + _spatial_dims = 2 + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + groups: int, + padding: int, + bias: bool = False, + fft_chunk_size: int = 0, + ) -> None: + _validate_depthwise_fft_args(in_channels, out_channels, kernel_size, groups, padding, bias) + nn.Conv2d.__init__( + self, in_channels, out_channels, kernel_size, stride=1, padding=padding, groups=groups, bias=False + ) + self.fft_chunk_size = fft_chunk_size + + +class DepthwiseFFTConv3d(_DepthwiseFFTForward, nn.Conv3d): + """3-D depthwise FFT convolution. ``isinstance(x, nn.Conv3d)`` remains ``True``. + + Drop-in replacement for an ``nn.Conv3d`` with ``groups == in_channels == out_channels`` + and ``bias=False``. Avoids the INT32 unfold limit that prevents ``F.conv3d`` from + running at ROI > ~128 for typical medical-imaging channel counts. + + Only ``"same"``-style padding is supported: ``padding`` must equal ``kernel_size // 2`` + (and ``kernel_size`` must be odd). The FFT forward crops its output back to the input + spatial size and does not implement general (e.g. ``"valid"``) padding. + """ + + _spatial_dims = 3 + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + groups: int, + padding: int, + bias: bool = False, + fft_chunk_size: int = 0, + ) -> None: + _validate_depthwise_fft_args(in_channels, out_channels, kernel_size, groups, padding, bias) + nn.Conv3d.__init__( + self, in_channels, out_channels, kernel_size, stride=1, padding=padding, groups=groups, bias=False + ) + self.fft_chunk_size = fft_chunk_size + + +# --------------------------------------------------------------------------- +# HyenaMixer — QKV-projected gated long-conv mixer +# --------------------------------------------------------------------------- + + +class HyenaMixer(nn.Module): + """QKV-projected gated long-convolution mixer using HyenaND. 
+ + Replaces self-attention with the HyenaND operator from ``nvsubquadratic``, providing + a global receptive field at O(N log N) cost via FFT. ``HyenaMixer`` matches the + channels-last layout ``[B, *spatial, C]`` expected by Swin-style transformer + blocks; the underlying HyenaND operator handles the 2-D / 3-D distinction. + + The HyenaND FFT path requires float32 precision, so the inner mixer call is + wrapped in ``torch.amp.autocast("cuda", enabled=False)``. Inputs are cast to + float32 before the mixer call and back to the original dtype after, so the block + is transparent under ``torch.autocast``. + + Args: + dim: hidden dimension. + spatial_dims: 2 or 3. + use_rope: kept for forward-compatibility with older configurations. + ``nvsubquadratic`` removed RoPE from HyenaND on 2026-04-27 and this kwarg + is now a silent no-op. + apply_qk_norm: whether to apply per-channel LayerNorm to Q (and to K when the + first gate is the identity, as is the default here). + short_conv_kernel_size: depthwise short-convolution kernel size on the + concatenated ``[Q; K; V]`` tensor. + kernel_mlp_hidden_dim: hidden dim of the SIREN implicit kernel MLP. + kernel_num_layers: depth of the SIREN implicit kernel. + kernel_omega_0: SIREN frequency. Default 10.0 (stable). Higher values allow + higher-frequency kernels but reduce training stability. + kernel_l_cache: SIREN coordinate-grid cache size per spatial dim. Memory + scales as ``(2L-1) ** D × D × 4`` bytes; set to ``>= max(spatial_dims)`` + to pre-allocate. + mask_max_attenuation: Gaussian-modulation attenuation at the grid boundary + for the widest channel (0–1). Default 0.95. + fft_padding: ``"circular"`` or ``"zero"``. + grid_type: ``"single"`` (kernel size = input size; required for circular + padding) or ``"double"`` (kernel size = 2× input size; only valid with + ``"zero"`` padding). + use_chunked_fftconv: chunk the FFT convolution by channel to reduce peak + memory ~26% with ~11% compute overhead. 
Requires ``fft_padding="zero"``. + use_fft_short_conv: replace the depthwise short conv with + :class:`DepthwiseFFTConv{2,3}d`, eliminating the INT32 unfold limit and + enabling unlimited ROI sizes. Adds ~11% compute overhead. + short_conv_fft_chunk_size: channel chunk size for the FFT short conv + (0 = no chunking). + + Raises: + ImportError: if ``nvsubquadratic`` is not installed. + ValueError: on invalid ``spatial_dims`` / ``fft_padding`` / ``grid_type`` + combinations. + """ + + def __init__( + self, + dim: int, + spatial_dims: int = 3, + use_rope: bool = True, + apply_qk_norm: bool = True, + short_conv_kernel_size: int = 3, + kernel_mlp_hidden_dim: int = 32, + kernel_num_layers: int = 3, + kernel_omega_0: float = 10.0, + kernel_l_cache: int = 32, + mask_max_attenuation: float = 0.95, + fft_padding: str = "circular", + grid_type: str = "single", + use_chunked_fftconv: bool = False, + use_fft_short_conv: bool = False, + short_conv_fft_chunk_size: int = 0, + ) -> None: + super().__init__() + + if not _has_nvsubq: + raise ImportError(_NVSUBQ_INSTALL_HINT) + + if fft_padding not in ("circular", "zero"): + raise ValueError(f"fft_padding must be 'circular' or 'zero', got '{fft_padding}'") + if grid_type not in ("single", "double"): + raise ValueError(f"grid_type must be 'single' or 'double', got '{grid_type}'") + if fft_padding == "circular" and grid_type != "single": + raise ValueError( + "fft_padding='circular' requires grid_type='single' " + f"(kernel size must match input size for periodic convolution); got grid_type='{grid_type}'" + ) + if use_chunked_fftconv and fft_padding != "zero": + raise ValueError( + "use_chunked_fftconv=True requires fft_padding='zero'; " f"got fft_padding='{fft_padding}'" + ) + + self.dim = dim + self.spatial_dims = spatial_dims + + if use_fft_short_conv: + if spatial_dims == 2: + conv_class = DepthwiseFFTConv2d + elif spatial_dims == 3: + conv_class = DepthwiseFFTConv3d + else: + raise ValueError(f"spatial_dims must be 2 or 3, got 
{spatial_dims}") + else: + if spatial_dims == 2: + conv_class = nn.Conv2d + elif spatial_dims == 3: + conv_class = nn.Conv3d + else: + raise ValueError(f"spatial_dims must be 2 or 3, got {spatial_dims}") + + global_conv_cfg = _LazyConfig(_CKConvND)( + data_dim=spatial_dims, + hidden_dim=dim, + kernel_cfg=_LazyConfig(_SIRENKernelND)( + data_dim=spatial_dims, + out_dim=dim, + mlp_hidden_dim=kernel_mlp_hidden_dim, + num_layers=kernel_num_layers, + embedding_dim=kernel_mlp_hidden_dim, + omega_0=kernel_omega_0, + L_cache=kernel_l_cache, + use_bias=True, + hidden_omega_0=1.0, + ), + mask_cfg=_LazyConfig(_GaussianModulationND)( + data_dim=spatial_dims, + num_channels=dim, + min_attenuation_at_step=0.1, + max_attenuation_at_limit=mask_max_attenuation, + init_extent=1.0, + parametrization="direct", + ), + grid_type=grid_type, + fft_padding=fft_padding, + use_chunked_fftconv=use_chunked_fftconv, + ) + + short_conv_kwargs: dict = dict( + in_channels=3 * dim, + out_channels=3 * dim, + kernel_size=short_conv_kernel_size, + groups=3 * dim, + padding=short_conv_kernel_size // 2, + bias=False, + ) + if use_fft_short_conv and short_conv_fft_chunk_size > 0: + short_conv_kwargs["fft_chunk_size"] = short_conv_fft_chunk_size + short_conv_cfg = _LazyConfig(conv_class)(**short_conv_kwargs) + + # ``use_rope`` retained on the API only; ``nvsubquadratic`` removed RoPE + # from ``Hyena.__init__`` on 2026-04-27. Saving the flag here keeps caller + # introspection intact ("did the user ask for RoPE?") while not affecting + # the constructed operator. 
+ self._use_rope_requested = use_rope + + self.mixer = _instantiate( + _LazyConfig(_Hyena)( + global_conv_cfg=global_conv_cfg, + short_conv_cfg=short_conv_cfg, + gate_nonlinear_cfg=_LazyConfig(nn.Identity)(), + pixelhyena_norm_cfg=_LazyConfig(nn.GroupNorm)(num_groups=1, num_channels=dim), + qk_norm_cfg=_LazyConfig(nn.LayerNorm)(normalized_shape=dim) if apply_qk_norm else None, + ) + ) + + self.qkv_proj = nn.Linear(dim, 3 * dim, bias=False) + self.out_proj = nn.Linear(dim, dim, bias=False) + self._init_weights() + + def _init_weights(self) -> None: + nn.init.normal_(self.qkv_proj.weight, std=0.02) + nn.init.normal_(self.out_proj.weight, std=0.02) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x: tensor of shape ``[batch, *spatial, dim]``. + + Returns: + Tensor of the same shape and dtype as the input. + """ + qkv = self.qkv_proj(x) + q, k, v = torch.chunk(qkv, 3, dim=-1) + + # HyenaND requires float32 internally; disable autocast for the mixer only, + # then restore the original (autocast) dtype afterwards. + with torch.amp.autocast("cuda", enabled=False): + q = q.float() + k = k.float() + v = v.float() + x = self.mixer(q, k, v) + x = x.to(qkv.dtype) + return self.out_proj(x) + + +# --------------------------------------------------------------------------- +# HyenaTransformerBlock — pre-norm Hyena + MLP residual +# --------------------------------------------------------------------------- + + +class HyenaTransformerBlock(nn.Module): + """Pre-norm transformer block with HyenaND in place of self-attention. + + Sandwiches a :class:`HyenaMixer` and an MLP between :class:`~torch.nn.LayerNorm` + layers in the standard transformer residual pattern, with optional gradient + checkpointing on each half. + + Args: + dim: number of feature channels. + spatial_dims: 2 or 3. + mlp_ratio: hidden / input ratio for the MLP. + drop: dropout rate inside the MLP. + drop_path: stochastic-depth rate applied to both the mixer and the MLP residual branches. 
+ act_layer: activation name passed to :class:`monai.networks.blocks.MLPBlock`. + norm_layer: normalization class (default :class:`~torch.nn.LayerNorm`). + use_checkpoint: enable gradient checkpointing on the mixer and MLP halves. + use_rope: forward-compatibility flag for older configs; no-op after the + ``nvsubquadratic`` 2026-04-27 RoPE removal. + apply_qk_norm: per-channel LayerNorm on Q (and K when the first gate is + identity, as is the default). + hyena_kernel_size: short-convolution kernel size on the QKV tensor. + hyena_kernel_mlp_dim: SIREN kernel MLP hidden dimension. + hyena_kernel_layers: SIREN kernel depth. + hyena_mask_max_attenuation: Gaussian-modulation boundary attenuation (0–1). + hyena_fft_padding: ``"circular"`` or ``"zero"``. + hyena_grid_type: ``"single"`` or ``"double"``. + hyena_use_chunked_fft: enable chunked FFT (requires zero padding). + hyena_use_fft_short_conv: use FFT for the short conv (no INT32 limit). + hyena_omega_0: SIREN ``omega_0``. Default 10.0. + hyena_l_cache: SIREN coordinate-grid cache size per dim. + hyena_short_conv_fft_chunks: channel chunk size for the FFT short conv. + + Raises: + ImportError: if ``nvsubquadratic`` is not installed. 
+ """ + + def __init__( + self, + dim: int, + spatial_dims: int = 3, + mlp_ratio: float = 4.0, + drop: float = 0.0, + drop_path: float = 0.0, + act_layer: str = "GELU", + norm_layer: type[LayerNorm] = nn.LayerNorm, # type: ignore[assignment] + use_checkpoint: bool = False, + use_rope: bool = True, + apply_qk_norm: bool = True, + hyena_kernel_size: int = 3, + hyena_kernel_mlp_dim: int = 32, + hyena_kernel_layers: int = 3, + hyena_mask_max_attenuation: float = 0.95, + hyena_fft_padding: str = "circular", + hyena_grid_type: str = "single", + hyena_use_chunked_fft: bool = False, + hyena_use_fft_short_conv: bool = False, + hyena_omega_0: float = 10.0, + hyena_l_cache: int = 32, + hyena_short_conv_fft_chunks: int = 0, + ) -> None: + super().__init__() + self.dim = dim + self.spatial_dims = spatial_dims + self.mlp_ratio = mlp_ratio + self.use_checkpoint = use_checkpoint + + self.norm1 = norm_layer(dim) + self.mixer = HyenaMixer( + dim=dim, + spatial_dims=spatial_dims, + use_rope=use_rope, + apply_qk_norm=apply_qk_norm, + short_conv_kernel_size=hyena_kernel_size, + kernel_mlp_hidden_dim=hyena_kernel_mlp_dim, + kernel_num_layers=hyena_kernel_layers, + kernel_omega_0=hyena_omega_0, + kernel_l_cache=hyena_l_cache, + mask_max_attenuation=hyena_mask_max_attenuation, + fft_padding=hyena_fft_padding, + grid_type=hyena_grid_type, + use_chunked_fftconv=hyena_use_chunked_fft, + use_fft_short_conv=hyena_use_fft_short_conv, + short_conv_fft_chunk_size=hyena_short_conv_fft_chunks, + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MLPBlock( + hidden_size=dim, mlp_dim=mlp_hidden_dim, act=act_layer, dropout_rate=drop, dropout_mode="swin" + ) + + def forward_part1(self, x: torch.Tensor) -> torch.Tensor: + x = self.norm1(x) + return self.mixer(x) + + def forward_part2(self, x: torch.Tensor) -> torch.Tensor: + return self.drop_path(self.mlp(self.norm2(x))) + + def 
forward(self, x: torch.Tensor, mask_matrix: torch.Tensor | None = None) -> torch.Tensor: + """Forward pass. + + Args: + x: input tensor of shape ``[batch, *spatial, dim]``. + mask_matrix: unused; accepted for signature parity with Swin's + ``WindowAttention``-based block so the two can be swapped at the + ``BasicLayer`` level without per-call branching. + + Returns: + Tensor of the same shape as the input. + """ + del mask_matrix + shortcut = x + if self.use_checkpoint: + x = torch.utils.checkpoint.checkpoint(self.forward_part1, x, use_reentrant=False) + else: + x = self.forward_part1(x) + x = shortcut + self.drop_path(x) + + if self.use_checkpoint: + x = x + torch.utils.checkpoint.checkpoint(self.forward_part2, x, use_reentrant=False) + else: + x = x + self.forward_part2(x) + return x diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index c1917e5293..fc0b33a0f0 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -53,6 +53,7 @@ from .generator import Generator from .highresnet import HighResBlock, HighResNet from .hovernet import Hovernet, HoVernet, HoVerNet, HoverNet +from .hyena_nd_unetr import HyenaNDUNETR from .masked_autoencoder_vit import MaskedAutoEncoderViT from .mednext import ( MedNeXt, diff --git a/monai/networks/nets/hyena_nd_unetr.py b/monai/networks/nets/hyena_nd_unetr.py new file mode 100644 index 0000000000..b6d8c120ac --- /dev/null +++ b/monai/networks/nets/hyena_nd_unetr.py @@ -0,0 +1,151 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" +HyenaNDUNETR: SwinUNETR with the HyenaND subquadratic operator in place of (or +mixed with) windowed self-attention. + +This is a thin convenience subclass of :class:`monai.networks.nets.SwinUNETR` whose +defaults make Hyena placement explicit: + +* ``use_hyena`` is forced to ``True``. +* ``hyena_stages`` is **required** -- callers must explicitly declare which of the + four Swin stages run HyenaND vs windowed attention. + +The classmethod :meth:`HyenaNDUNETR.get_variant` provides the three Hyena +variants from Table 4 of the NeurIPS 2026 paper "Native Multi-Dimensional Subquadratic +Operators via Input Dependent Long Convolutions" (paper id 26539): + +========== ================================= ============================== +Variant ``hyena_stages`` Notes +========== ================================= ============================== +``HHHH`` ``(True, True, True, True)`` Hyena at every Swin stage +``HAHA`` ``(True, False, True, False)`` striped/interleaved +``HHAA`` ``(True, True, False, False)`` paper-best (outer Hyena, inner attention) +========== ================================= ============================== + +``AAAA`` (pure attention) is intentionally not exposed here -- it is plain +:class:`SwinUNETR` and constructing a "HyenaNDUNETR" with no Hyena stages would be a +contradiction. + +Requires the optional ``nvsubquadratic`` package; install with +``pip install monai[hyena]``. +""" + +from __future__ import annotations + +from collections.abc import Sequence + +from monai.networks.nets.swin_unetr import SwinUNETR + +__all__ = ["HyenaNDUNETR"] + + +# Per-stage Hyena patterns for the three paper variants. 
+PAPER_VARIANTS: dict[str, tuple[bool, ...]] = { + "HHHH": (True, True, True, True), + "HAHA": (True, False, True, False), + "HHAA": (True, True, False, False), +} + + +class HyenaNDUNETR(SwinUNETR): + """SwinUNETR with HyenaND replacing windowed self-attention at selected stages. + + See the module docstring for the paper-variant table and + :meth:`get_variant` for a convenience constructor matching the NeurIPS + 2026 paper. All other kwargs are forwarded to :class:`SwinUNETR`. + + Args: + in_channels: dimension of input channels. + out_channels: dimension of output channels. + hyena_stages: required 4-tuple of bools, one per Swin stage. At least one + element must be ``True`` (otherwise use :class:`SwinUNETR` directly). + feature_size: dimension of network feature size. Must be a multiple of 12 + (inherited from :class:`SwinUNETR`). + **kwargs: forwarded to :class:`SwinUNETR`. + + Raises: + ValueError: if ``hyena_stages`` is ``None`` or a string, has the wrong length, or has no ``True`` element. + TypeError: if ``use_hyena`` is passed via ``kwargs`` (it is always forced to ``True``). + ImportError: if the optional ``nvsubquadratic`` package is not installed. + """ + + def __init__( + self, in_channels: int, out_channels: int, hyena_stages: Sequence[bool], feature_size: int = 48, **kwargs + ) -> None: + if hyena_stages is None or isinstance(hyena_stages, str): # "HAHA" would otherwise coerce to all-True + raise ValueError( + "HyenaNDUNETR requires `hyena_stages` (a 4-tuple of bools); " + "use SwinUNETR directly for pure attention." + ) + stages_tuple = tuple(bool(s) for s in hyena_stages) + if len(stages_tuple) != 4: + raise ValueError( + f"hyena_stages must have length 4 (one bool per Swin stage); got length {len(stages_tuple)}." + ) + if not any(stages_tuple): + raise ValueError( + "hyena_stages must enable HyenaND at at least one stage; " "use SwinUNETR directly for pure attention." + ) + + # ``use_hyena`` is forced True here; reject it in kwargs rather than silently + # override -- the subclass exists to make Hyena placement explicit. (``hyena_stages`` + # is an explicit parameter above, so it can never reach ``kwargs``.)
+ if "use_hyena" in kwargs: + raise TypeError("HyenaNDUNETR forces use_hyena=True; do not pass use_hyena via kwargs.") + + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + feature_size=feature_size, + use_hyena=True, + hyena_stages=stages_tuple, + **kwargs, + ) + + @classmethod + def get_variant(cls, variant: str, **kwargs) -> HyenaNDUNETR: + """Build a :class:`HyenaNDUNETR` matching one of the NeurIPS 2026 paper variants. + + Args: + variant: one of ``"HHHH"``, ``"HAHA"``, ``"HHAA"`` (case-insensitive). + **kwargs: forwarded to :class:`HyenaNDUNETR.__init__`. Must include at + least ``in_channels`` and ``out_channels``. Must NOT include + ``hyena_stages`` (set by the variant). + + Returns: + A :class:`HyenaNDUNETR` with ``hyena_stages`` set per the variant. + + Raises: + ValueError: if ``variant`` is not one of the three known names, or if + ``hyena_stages`` is also passed via kwargs. + + Example:: + + >>> net = HyenaNDUNETR.get_variant( + ... "HHAA", + ... in_channels=1, + ... out_channels=29, + ... feature_size=48, + ... ) + """ + key = variant.upper() + if key not in PAPER_VARIANTS: + raise ValueError( + f"Unknown paper variant '{variant}'. " + f"Known variants: {sorted(PAPER_VARIANTS)}. " + "(AAAA is plain SwinUNETR; use that class directly.)" + ) + if "hyena_stages" in kwargs: + raise ValueError( + "get_variant sets hyena_stages from the variant name; " "do not also pass hyena_stages via kwargs." 
+ ) + return cls(hyena_stages=PAPER_VARIANTS[key], **kwargs) diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index 0db2d50d26..fa944b0920 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -23,6 +23,7 @@ from monai.networks.blocks import MLPBlock as Mlp from monai.networks.blocks import PatchEmbed, UnetOutBlock, UnetrBasicBlock, UnetrUpBlock +from monai.networks.blocks.hyena import HyenaTransformerBlock from monai.networks.layers import DropPath, trunc_normal_ from monai.utils import ensure_tuple_rep, look_up_option, optional_import @@ -84,6 +85,19 @@ def __init__( spatial_dims: int = 3, downsample: str | nn.Module = "merging", use_v2: bool = False, + use_hyena: bool = False, + hyena_stages: Sequence[bool] | None = None, + hyena_kernel_size: int = 3, + hyena_kernel_mlp_dim: int = 32, + hyena_kernel_layers: int = 3, + hyena_mask_max_attenuation: float = 0.95, + hyena_fft_padding: str = "circular", + hyena_grid_type: str = "single", + hyena_use_chunked_fft: bool = False, + hyena_use_fft_short_conv: bool = False, + hyena_omega_0: float = 10.0, + hyena_l_cache: int = 32, + hyena_short_conv_fft_chunks: int = 0, ) -> None: """ Args: @@ -110,6 +124,31 @@ def __init__( user-specified `nn.Module` following the API defined in :py:class:`monai.networks.nets.PatchMerging`. The default is currently `"merging"` (the original version defined in v0.9.0). use_v2: using swinunetr_v2, which adds a residual convolution block at the beggining of each swin stage. + use_hyena: replace windowed self-attention with the HyenaND operator (subquadratic O(N log N) + global convolution) in every Swin stage. Default ``False`` keeps the model bit-identical + to the pre-HyenaND code path. Requires the optional ``nvsubquadratic`` package + (``pip install monai[hyena]``). When combined with ``hyena_stages``, the per-stage flag + overrides this master switch on a per-stage basis. 
+ hyena_stages: optional 4-tuple of bools selecting which Swin stages use HyenaND vs windowed + attention. ``True`` at index ``i`` builds :class:`HyenaTransformerBlock` at stage ``i``, + ``False`` builds the conventional :class:`SwinTransformerBlock`. The four NeurIPS 2026 + paper variants are: ``None`` (AAAA, all attention, requires ``use_hyena=False``); + ``(True, True, True, True)`` (HHHH, equivalent to ``use_hyena=True``); ``(True, False, + True, False)`` (HAHA); ``(True, True, False, False)`` (HHAA, paper-best). + hyena_kernel_size: HyenaND short-convolution kernel size (depthwise on QKV). + hyena_kernel_mlp_dim: SIREN implicit-kernel MLP hidden dimension. + hyena_kernel_layers: SIREN implicit-kernel depth. + hyena_mask_max_attenuation: Gaussian-modulation boundary attenuation (0-1). + hyena_fft_padding: ``"circular"`` or ``"zero"``. ``"circular"`` was the paper-best setting. + hyena_grid_type: ``"single"`` (kernel = input size, required for circular) or ``"double"`` + (kernel = 2x input size, requires zero padding). + hyena_use_chunked_fft: enable chunked FFT for ~26 percent memory savings; requires + ``hyena_fft_padding="zero"``. + hyena_use_fft_short_conv: replace the short conv with :class:`DepthwiseFFTConv{2,3}d` to + eliminate the INT32 unfold limit and enable ROI > 128. + hyena_omega_0: SIREN frequency. Default 10.0 (stable). + hyena_l_cache: SIREN coordinate-grid cache size per spatial dim. + hyena_short_conv_fft_chunks: channel chunk size for the FFT short conv (0 = no chunking). 
Examples:: @@ -151,6 +190,8 @@ def __init__( raise ValueError("feature_size should be divisible by 12.") self.normalize = normalize + self.use_hyena = use_hyena + self.hyena_stages = tuple(bool(s) for s in hyena_stages) if hyena_stages is not None else None self.swinViT = SwinTransformer( in_chans=in_channels, @@ -170,6 +211,19 @@ def __init__( spatial_dims=spatial_dims, downsample=look_up_option(downsample, MERGING_MODE) if isinstance(downsample, str) else downsample, use_v2=use_v2, + use_hyena=use_hyena, + hyena_stages=self.hyena_stages, + hyena_kernel_size=hyena_kernel_size, + hyena_kernel_mlp_dim=hyena_kernel_mlp_dim, + hyena_kernel_layers=hyena_kernel_layers, + hyena_mask_max_attenuation=hyena_mask_max_attenuation, + hyena_fft_padding=hyena_fft_padding, + hyena_grid_type=hyena_grid_type, + hyena_use_chunked_fft=hyena_use_chunked_fft, + hyena_use_fft_short_conv=hyena_use_fft_short_conv, + hyena_omega_0=hyena_omega_0, + hyena_l_cache=hyena_l_cache, + hyena_short_conv_fft_chunks=hyena_short_conv_fft_chunks, ) self.encoder1 = UnetrBasicBlock( @@ -274,50 +328,52 @@ def __init__( self.out = UnetOutBlock(spatial_dims=spatial_dims, in_channels=feature_size, out_channels=out_channels) def load_from(self, weights): + """Load pretrained Swin weights into the matching submodules. + + When a stage uses :class:`HyenaTransformerBlock` instead of + :class:`SwinTransformerBlock`, the per-block ``load_from`` call is skipped for + that stage and a warning is issued -- HyenaND has a different parameter layout + and there are no compatible attention weights to copy. PatchMerging + downsample weights are still loaded for all stages (the downsample layer is + the same in both code paths). 
+ """ + import warnings + layers1_0: BasicLayer = self.swinViT.layers1[0] # type: ignore[assignment] layers2_0: BasicLayer = self.swinViT.layers2[0] # type: ignore[assignment] layers3_0: BasicLayer = self.swinViT.layers3[0] # type: ignore[assignment] layers4_0: BasicLayer = self.swinViT.layers4[0] # type: ignore[assignment] wstate = weights["state_dict"] + def _stage_is_hyena(stage_layer: BasicLayer) -> bool: + first_block = next(iter(stage_layer.blocks.children()), None) # None: empty stage, treated as non-Hyena + return isinstance(first_block, HyenaTransformerBlock) + with torch.no_grad(): self.swinViT.patch_embed.proj.weight.copy_(wstate["module.patch_embed.proj.weight"]) self.swinViT.patch_embed.proj.bias.copy_(wstate["module.patch_embed.proj.bias"]) - for bname, block in layers1_0.blocks.named_children(): - block.load_from(weights, n_block=bname, layer="layers1") # type: ignore[operator] - - if layers1_0.downsample is not None: - d = layers1_0.downsample - d.reduction.weight.copy_(wstate["module.layers1.0.downsample.reduction.weight"]) # type: ignore - d.norm.weight.copy_(wstate["module.layers1.0.downsample.norm.weight"]) # type: ignore - d.norm.bias.copy_(wstate["module.layers1.0.downsample.norm.bias"]) # type: ignore - - for bname, block in layers2_0.blocks.named_children(): - block.load_from(weights, n_block=bname, layer="layers2") # type: ignore[operator] - - if layers2_0.downsample is not None: - d = layers2_0.downsample - d.reduction.weight.copy_(wstate["module.layers2.0.downsample.reduction.weight"]) # type: ignore - d.norm.weight.copy_(wstate["module.layers2.0.downsample.norm.weight"]) # type: ignore - d.norm.bias.copy_(wstate["module.layers2.0.downsample.norm.bias"]) # type: ignore - - for bname, block in layers3_0.blocks.named_children(): - block.load_from(weights, n_block=bname, layer="layers3") # type: ignore[operator] - - if layers3_0.downsample is not None: - d = layers3_0.downsample - d.reduction.weight.copy_(wstate["module.layers3.0.downsample.reduction.weight"]) # type: ignore -
d.norm.weight.copy_(wstate["module.layers3.0.downsample.norm.weight"]) # type: ignore - d.norm.bias.copy_(wstate["module.layers3.0.downsample.norm.bias"]) # type: ignore - - for bname, block in layers4_0.blocks.named_children(): - block.load_from(weights, n_block=bname, layer="layers4") # type: ignore[operator] - - if layers4_0.downsample is not None: - d = layers4_0.downsample - d.reduction.weight.copy_(wstate["module.layers4.0.downsample.reduction.weight"]) # type: ignore - d.norm.weight.copy_(wstate["module.layers4.0.downsample.norm.weight"]) # type: ignore - d.norm.bias.copy_(wstate["module.layers4.0.downsample.norm.bias"]) # type: ignore + + for layer_name, stage in [ + ("layers1", layers1_0), + ("layers2", layers2_0), + ("layers3", layers3_0), + ("layers4", layers4_0), + ]: + if _stage_is_hyena(stage): + warnings.warn( + f"Skipping {layer_name} block weights: stage uses HyenaTransformerBlock, " + "which has no compatible Swin attention weights. Blocks remain at their " + "random initialization.", + stacklevel=2, + ) + else: + for bname, block in stage.blocks.named_children(): + block.load_from(weights, n_block=bname, layer=layer_name) # type: ignore[operator] + if stage.downsample is not None: + d = stage.downsample + d.reduction.weight.copy_(wstate[f"module.{layer_name}.0.downsample.reduction.weight"]) # type: ignore + d.norm.weight.copy_(wstate[f"module.{layer_name}.0.downsample.norm.weight"]) # type: ignore + d.norm.bias.copy_(wstate[f"module.{layer_name}.0.downsample.norm.bias"]) # type: ignore @torch.jit.unused def _check_input_size(self, spatial_shape): @@ -856,6 +912,18 @@ def __init__( norm_layer: type[LayerNorm] = nn.LayerNorm, downsample: nn.Module | None = None, use_checkpoint: bool = False, + use_hyena: bool = False, + hyena_kernel_size: int = 3, + hyena_kernel_mlp_dim: int = 32, + hyena_kernel_layers: int = 3, + hyena_mask_max_attenuation: float = 0.95, + hyena_fft_padding: str = "circular", + hyena_grid_type: str = "single", + 
hyena_use_chunked_fft: bool = False, + hyena_use_fft_short_conv: bool = False, + hyena_omega_0: float = 10.0, + hyena_l_cache: int = 32, + hyena_short_conv_fft_chunks: int = 0, ) -> None: """ Args: @@ -871,6 +939,13 @@ def __init__( norm_layer: normalization layer. downsample: an optional downsampling layer at the end of the layer. use_checkpoint: use gradient checkpointing for reduced memory usage. + use_hyena: replace :class:`SwinTransformerBlock` with :class:`HyenaTransformerBlock` + in this stage. See :class:`SwinUNETR` for the per-stage selection mechanism. + hyena_kernel_size, hyena_kernel_mlp_dim, hyena_kernel_layers, + hyena_mask_max_attenuation, hyena_fft_padding, hyena_grid_type, + hyena_use_chunked_fft, hyena_use_fft_short_conv, hyena_omega_0, hyena_l_cache, + hyena_short_conv_fft_chunks: forwarded to :class:`HyenaTransformerBlock`. See its + docstring for semantics. """ super().__init__() @@ -879,24 +954,52 @@ def __init__( self.no_shift = tuple(0 for i in window_size) self.depth = depth self.use_checkpoint = use_checkpoint - self.blocks = nn.ModuleList( - [ - SwinTransformerBlock( - dim=dim, - num_heads=num_heads, - window_size=self.window_size, - shift_size=self.no_shift if (i % 2 == 0) else self.shift_size, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - drop=drop, - attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer, - use_checkpoint=use_checkpoint, - ) - for i in range(depth) - ] - ) + self.use_hyena = use_hyena + if use_hyena: + self.blocks = nn.ModuleList( + [ + HyenaTransformerBlock( + dim=dim, + spatial_dims=len(self.window_size), + mlp_ratio=mlp_ratio, + drop=drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + use_checkpoint=use_checkpoint, + hyena_kernel_size=hyena_kernel_size, + hyena_kernel_mlp_dim=hyena_kernel_mlp_dim, + hyena_kernel_layers=hyena_kernel_layers, + hyena_mask_max_attenuation=hyena_mask_max_attenuation, + 
hyena_fft_padding=hyena_fft_padding, + hyena_grid_type=hyena_grid_type, + hyena_use_chunked_fft=hyena_use_chunked_fft, + hyena_use_fft_short_conv=hyena_use_fft_short_conv, + hyena_omega_0=hyena_omega_0, + hyena_l_cache=hyena_l_cache, + hyena_short_conv_fft_chunks=hyena_short_conv_fft_chunks, + ) + for i in range(depth) + ] + ) + else: + self.blocks = nn.ModuleList( + [ + SwinTransformerBlock( + dim=dim, + num_heads=num_heads, + window_size=self.window_size, + shift_size=self.no_shift if (i % 2 == 0) else self.shift_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + use_checkpoint=use_checkpoint, + ) + for i in range(depth) + ] + ) self.downsample = downsample if callable(self.downsample): self.downsample = downsample(dim=dim, norm_layer=norm_layer, spatial_dims=len(self.window_size)) @@ -910,7 +1013,8 @@ def forward(self, x): dp = int(np.ceil(d / window_size[0])) * window_size[0] hp = int(np.ceil(h / window_size[1])) * window_size[1] wp = int(np.ceil(w / window_size[2])) * window_size[2] - attn_mask = compute_mask([dp, hp, wp], window_size, shift_size, x.device) + # HyenaTransformerBlock ignores the attention mask; skip building it for Hyena stages. + attn_mask = None if self.use_hyena else compute_mask([dp, hp, wp], window_size, shift_size, x.device) for blk in self.blocks: x = blk(x, attn_mask) x = x.view(b, d, h, w, -1) @@ -924,7 +1028,8 @@ def forward(self, x): x = rearrange(x, "b c h w -> b h w c") hp = int(np.ceil(h / window_size[0])) * window_size[0] wp = int(np.ceil(w / window_size[1])) * window_size[1] - attn_mask = compute_mask([hp, wp], window_size, shift_size, x.device) + # HyenaTransformerBlock ignores the attention mask; skip building it for Hyena stages. 
+ attn_mask = None if self.use_hyena else compute_mask([hp, wp], window_size, shift_size, x.device) for blk in self.blocks: x = blk(x, attn_mask) x = x.view(b, h, w, -1) @@ -961,6 +1066,19 @@ def __init__( spatial_dims: int = 3, downsample="merging", use_v2=False, + use_hyena: bool = False, + hyena_stages: Sequence[bool] | None = None, + hyena_kernel_size: int = 3, + hyena_kernel_mlp_dim: int = 32, + hyena_kernel_layers: int = 3, + hyena_mask_max_attenuation: float = 0.95, + hyena_fft_padding: str = "circular", + hyena_grid_type: str = "single", + hyena_use_chunked_fft: bool = False, + hyena_use_fft_short_conv: bool = False, + hyena_omega_0: float = 10.0, + hyena_l_cache: int = 32, + hyena_short_conv_fft_chunks: int = 0, ) -> None: """ Args: @@ -983,6 +1101,17 @@ def __init__( user-specified `nn.Module` following the API defined in :py:class:`monai.networks.nets.PatchMerging`. The default is currently `"merging"` (the original version defined in v0.9.0). use_v2: using swinunetr_v2, which adds a residual convolution block at the beginning of each swin stage. + use_hyena: build :class:`HyenaTransformerBlock` instead of :class:`SwinTransformerBlock` + in every stage. See :class:`SwinUNETR` for paper-variant patterns. + hyena_stages: optional per-stage override (4-tuple of bools); a stage flagged ``True`` + builds a Hyena block regardless of ``use_hyena``, and a stage flagged ``False`` + builds a Swin block regardless of ``use_hyena``. ``None`` falls back to ``use_hyena`` + for all stages. + hyena_kernel_size, hyena_kernel_mlp_dim, hyena_kernel_layers, + hyena_mask_max_attenuation, hyena_fft_padding, hyena_grid_type, + hyena_use_chunked_fft, hyena_use_fft_short_conv, hyena_omega_0, hyena_l_cache, + hyena_short_conv_fft_chunks: HyenaND configuration. See + :class:`monai.networks.blocks.HyenaTransformerBlock` for semantics. 
""" super().__init__() @@ -991,6 +1120,33 @@ def __init__( self.patch_norm = patch_norm self.window_size = window_size self.patch_size = patch_size + + # Per-stage Hyena selection: explicit ``hyena_stages`` overrides the master flag. + self._per_stage_hyena: list[bool] = ( + [bool(s) for s in hyena_stages] if hyena_stages is not None else [bool(use_hyena)] * self.num_layers + ) + if len(self._per_stage_hyena) != self.num_layers: + raise ValueError( + f"hyena_stages must have length {self.num_layers} (one bool per Swin stage); " + f"got length {len(self._per_stage_hyena)}." + ) + + # Legacy RoPE-divisibility guard: kept as defensive validation for callers that bypass the + # SwinUNETR-level ``feature_size % 12 == 0`` check. ``nvsubquadratic`` removed RoPE from + # the HyenaND operator on 2026-04-27, so this check is now slightly conservative; it does + # not affect any valid SwinUNETR configuration. + if any(self._per_stage_hyena): + div = 6 if spatial_dims == 3 else 4 + for i, use_h in enumerate(self._per_stage_hyena): + if use_h: + dim_at_layer = int(embed_dim * 2**i) + if dim_at_layer % div != 0: + raise ValueError( + f"For {spatial_dims}D Hyena, embed_dim * 2^layer must be divisible by {div}. " + f"At layer {i}, dim={dim_at_layer} is not. " + "Use embed_dim that is a multiple of 12 (the SwinUNETR default check)." 
+ ) + self.patch_embed = PatchEmbed( patch_size=self.patch_size, in_chans=in_chans, @@ -1025,6 +1181,18 @@ def __init__( norm_layer=norm_layer, downsample=down_sample_mod, use_checkpoint=use_checkpoint, + use_hyena=self._per_stage_hyena[i_layer], + hyena_kernel_size=hyena_kernel_size, + hyena_kernel_mlp_dim=hyena_kernel_mlp_dim, + hyena_kernel_layers=hyena_kernel_layers, + hyena_mask_max_attenuation=hyena_mask_max_attenuation, + hyena_fft_padding=hyena_fft_padding, + hyena_grid_type=hyena_grid_type, + hyena_use_chunked_fft=hyena_use_chunked_fft, + hyena_use_fft_short_conv=hyena_use_fft_short_conv, + hyena_omega_0=hyena_omega_0, + hyena_l_cache=hyena_l_cache, + hyena_short_conv_fft_chunks=hyena_short_conv_fft_chunks, ) if i_layer == 0: self.layers1.append(layer) diff --git a/setup.cfg b/setup.cfg index d987141d0b..c025c685f4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -132,6 +132,8 @@ pandas = pandas einops = einops +hyena = + nvsubquadratic>=0.1.1 transformers = transformers>=4.36.0, <4.41.0; python_version <= '3.10' mlflow = diff --git a/tests/min_tests.py b/tests/min_tests.py index 2d68f099a7..f98bf4b739 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -112,6 +112,8 @@ def run_testsuit(): "test_hausdorff_distance", "test_header_correct", "test_hilbert_transform", + "test_hyena_block", + "test_hyena_nd_unetr", "test_hovernet_loss", "test_image_dataset", "test_image_rw", diff --git a/tests/networks/blocks/test_hyena_block.py b/tests/networks/blocks/test_hyena_block.py new file mode 100644 index 0000000000..5f3835e2df --- /dev/null +++ b/tests/networks/blocks/test_hyena_block.py @@ -0,0 +1,336 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import math +import unittest +from unittest import skipUnless + +import torch +import torch.nn as nn +from parameterized import parameterized + +from monai.networks.blocks.hyena import ( + DepthwiseFFTConv2d, + DepthwiseFFTConv3d, + HyenaMixer, + HyenaTransformerBlock, + is_nvsubquadratic_available, +) + +HAS_NVSUBQ = is_nvsubquadratic_available() +HAS_CUDA = torch.cuda.is_available() + + +# --------------------------------------------------------------------------- +# DepthwiseFFTConv{2,3}d — no nvsubquadratic dependency +# --------------------------------------------------------------------------- + + +class TestDepthwiseFFTConvShape(unittest.TestCase): + """The FFT conv must preserve spatial dimensions for any depthwise config.""" + + @parameterized.expand( + [ + ("3d_k3_d16", (2, 8, 16, 16, 16), 3, 3), + ("3d_k1_d10", (1, 16, 10, 10, 10), 1, 3), + ("3d_k5_d12", (1, 8, 12, 12, 12), 5, 3), + ("2d_k3_d32", (2, 8, 32, 32), 3, 2), + ] + ) + def test_output_shape(self, _name, input_shape, kernel_size, spatial_dims): + channels = input_shape[1] + cls = DepthwiseFFTConv3d if spatial_dims == 3 else DepthwiseFFTConv2d + conv = cls(channels, channels, kernel_size=kernel_size, groups=channels, padding=kernel_size // 2) + x = torch.randn(*input_shape) + self.assertEqual(conv(x).shape, x.shape) + + +class TestDepthwiseFFTConvNumerics(unittest.TestCase): + """FFT conv must match the equivalent ``nn.Conv{2,3}d`` numerically.""" + + @parameterized.expand([("d8_s12", 8, 12), ("d16_s8", 16, 8), ("d32_s6", 32, 6)]) + def 
test_matches_conv3d(self, _name, channels, spatial): + ref = nn.Conv3d(channels, channels, kernel_size=3, groups=channels, padding=1, bias=False) + fft = DepthwiseFFTConv3d(channels, channels, kernel_size=3, groups=channels, padding=1) + with torch.no_grad(): + fft.weight.copy_(ref.weight) + x = torch.randn(2, channels, spatial, spatial, spatial) + with torch.no_grad(): + torch.testing.assert_close(fft(x), ref(x), atol=1e-4, rtol=1e-4) + + def test_matches_conv2d(self): + channels, spatial = 8, 16 + ref = nn.Conv2d(channels, channels, kernel_size=3, groups=channels, padding=1, bias=False) + fft = DepthwiseFFTConv2d(channels, channels, kernel_size=3, groups=channels, padding=1) + with torch.no_grad(): + fft.weight.copy_(ref.weight) + x = torch.randn(2, channels, spatial, spatial) + with torch.no_grad(): + torch.testing.assert_close(fft(x), ref(x), atol=1e-4, rtol=1e-4) + + +class TestDepthwiseFFTConvDtype(unittest.TestCase): + """Output dtype must match input dtype (AMP transparency).""" + + @parameterized.expand([("fp16", torch.float16), ("bf16", torch.bfloat16)]) + def test_amp_dtype_preserved(self, _name, dtype): + conv = DepthwiseFFTConv3d(8, 8, kernel_size=3, groups=8, padding=1) + x = torch.randn(1, 8, 8, 8, 8, dtype=dtype) + out = conv(x) + self.assertEqual(out.dtype, dtype) + self.assertEqual(out.shape, x.shape) + + def test_float32_preserved(self): + conv = DepthwiseFFTConv3d(8, 8, kernel_size=3, groups=8, padding=1) + x = torch.randn(1, 8, 8, 8, 8) + self.assertEqual(conv(x).dtype, torch.float32) + + +class TestDepthwiseFFTConvGradients(unittest.TestCase): + """Backward pass must produce gradients on both input and weight.""" + + def test_gradients_flow_3d(self): + conv = DepthwiseFFTConv3d(8, 8, kernel_size=3, groups=8, padding=1) + x = torch.randn(1, 8, 10, 10, 10, requires_grad=True) + conv(x).sum().backward() + self.assertIsNotNone(x.grad) + self.assertIsNotNone(conv.weight.grad) + + def test_gradients_flow_2d(self): + conv = DepthwiseFFTConv2d(8, 8, 
kernel_size=3, groups=8, padding=1) + x = torch.randn(1, 8, 16, 16, requires_grad=True) + conv(x).sum().backward() + self.assertIsNotNone(x.grad) + self.assertIsNotNone(conv.weight.grad) + + +class TestDepthwiseFFTConvConstruction(unittest.TestCase): + """Reject configurations the FFT path cannot represent.""" + + def test_rejects_non_depthwise(self): + with self.assertRaises(ValueError): + DepthwiseFFTConv3d(8, 8, kernel_size=3, groups=1, padding=1) + + def test_rejects_bias(self): + with self.assertRaises(ValueError): + DepthwiseFFTConv3d(8, 8, kernel_size=3, groups=8, padding=1, bias=True) + + def test_rejects_non_same_padding(self): + # forward() crops to the input size assuming padding == kernel_size // 2. + with self.assertRaisesRegex(ValueError, "same"): + DepthwiseFFTConv3d(8, 8, kernel_size=3, groups=8, padding=0) + + def test_rejects_even_kernel(self): + with self.assertRaisesRegex(ValueError, "same"): + DepthwiseFFTConv3d(8, 8, kernel_size=4, groups=8, padding=2) + + def test_weight_shape(self): + conv = DepthwiseFFTConv3d(16, 16, kernel_size=3, groups=16, padding=1) + self.assertEqual(conv.weight.shape, (16, 1, 3, 3, 3)) + + def test_weight_initialised(self): + conv = DepthwiseFFTConv3d(64, 64, kernel_size=3, groups=64, padding=1) + # kaiming_uniform with fan_in = 1 * 3^3 = 27 → bound ≈ 1/sqrt(27) ≈ 0.19 + self.assertGreater(conv.weight.abs().max().item(), 0.0) + self.assertLess(conv.weight.abs().max().item(), 5.0 / math.sqrt(27)) + + +# --------------------------------------------------------------------------- +# HyenaMixer configuration validation — no CUDA required (construction only, +# but nvsubquadratic must be present to reach the validation branch) +# --------------------------------------------------------------------------- + + +@skipUnless(HAS_NVSUBQ, "Requires nvsubquadratic") +class TestHyenaMixerConfigValidation(unittest.TestCase): + def test_rejects_circular_with_double_grid(self): + with self.assertRaisesRegex(ValueError, 
"circular.*single"): + HyenaMixer(dim=12, spatial_dims=3, fft_padding="circular", grid_type="double") + + def test_rejects_chunked_with_circular(self): + with self.assertRaisesRegex(ValueError, "chunked.*zero|zero.*chunked"): + HyenaMixer(dim=12, spatial_dims=3, fft_padding="circular", use_chunked_fftconv=True) + + def test_rejects_bad_fft_padding(self): + with self.assertRaisesRegex(ValueError, "fft_padding"): + HyenaMixer(dim=12, spatial_dims=3, fft_padding="reflective") + + def test_rejects_bad_grid_type(self): + with self.assertRaisesRegex(ValueError, "grid_type"): + HyenaMixer(dim=12, spatial_dims=3, grid_type="triple") + + def test_rejects_bad_spatial_dims(self): + with self.assertRaisesRegex(ValueError, "spatial_dims"): + HyenaMixer(dim=12, spatial_dims=4) + + def test_zero_double_chunked_constructs(self): + m = HyenaMixer(dim=12, spatial_dims=3, fft_padding="zero", grid_type="double", use_chunked_fftconv=True) + self.assertEqual(m.dim, 12) + + +class TestHyenaMixerOptionalDep(unittest.TestCase): + """When ``nvsubquadratic`` is missing, ``HyenaMixer`` must raise a clear ImportError.""" + + @skipUnless(not HAS_NVSUBQ, "Only runs when nvsubquadratic is absent") + def test_raises_import_error(self): + with self.assertRaisesRegex(ImportError, "nvsubquadratic"): + HyenaMixer(dim=12, spatial_dims=3) + + +# --------------------------------------------------------------------------- +# Forward shape — channels-last [B, *spatial, C] preserved +# --------------------------------------------------------------------------- + + +@skipUnless(HAS_NVSUBQ and HAS_CUDA, "Requires nvsubquadratic and CUDA") +class TestHyenaMixerForward(unittest.TestCase): + device = "cuda" + + def test_3d_forward_shape(self): + m = HyenaMixer(dim=12, spatial_dims=3).to(self.device) + x = torch.randn(2, 8, 8, 8, 12, device=self.device) + self.assertEqual(m(x).shape, x.shape) + + def test_2d_forward_shape(self): + m = HyenaMixer(dim=8, spatial_dims=2).to(self.device) + x = torch.randn(2, 16, 16, 
8, device=self.device) + self.assertEqual(m(x).shape, x.shape) + + def test_zero_padding_forward(self): + m = HyenaMixer(dim=12, spatial_dims=3, fft_padding="zero", grid_type="single").to(self.device) + x = torch.randn(2, 8, 8, 8, 12, device=self.device) + self.assertEqual(m(x).shape, x.shape) + + def test_zero_double_chunked_forward(self): + m = HyenaMixer(dim=12, spatial_dims=3, fft_padding="zero", grid_type="double", use_chunked_fftconv=True).to( + self.device + ) + x = torch.randn(2, 8, 8, 8, 12, device=self.device) + self.assertEqual(m(x).shape, x.shape) + + +@skipUnless(HAS_NVSUBQ and HAS_CUDA, "Requires nvsubquadratic and CUDA") +class TestHyenaMixerGradients(unittest.TestCase): + device = "cuda" + + def test_qkv_and_out_proj_get_grads(self): + m = HyenaMixer(dim=12, spatial_dims=3).to(self.device) + x = torch.randn(2, 6, 6, 6, 12, device=self.device, requires_grad=True) + m(x).sum().backward() + self.assertIsNotNone(m.qkv_proj.weight.grad) + self.assertIsNotNone(m.out_proj.weight.grad) + self.assertIsNotNone(x.grad) + + def test_mixer_internal_params_get_grads(self): + m = HyenaMixer(dim=12, spatial_dims=3).to(self.device) + x = torch.randn(2, 6, 6, 6, 12, device=self.device) + m(x).sum().backward() + with_grad = [ + name for name, p in m.mixer.named_parameters() if p.grad is not None and p.grad.abs().sum().item() > 0 + ] + self.assertGreater(len(with_grad), 0, "no mixer-internal params received a gradient") + + +@skipUnless(HAS_NVSUBQ and HAS_CUDA, "Requires nvsubquadratic and CUDA") +class TestHyenaMixerAMP(unittest.TestCase): + """Under ``torch.autocast`` the output dtype must match the autocast dtype.""" + + device = "cuda" + + @parameterized.expand([("fp16", torch.float16), ("bf16", torch.bfloat16)]) + def test_autocast_output_dtype(self, _name, dtype): + m = HyenaMixer(dim=12, spatial_dims=3).to(self.device) + x = torch.randn(2, 6, 6, 6, 12, device=self.device) + with torch.autocast("cuda", dtype=dtype): + out = m(x) + self.assertEqual(out.dtype, 
dtype) + self.assertEqual(out.shape, x.shape) + + def test_float32_preserved(self): + m = HyenaMixer(dim=12, spatial_dims=3).to(self.device) + x = torch.randn(2, 6, 6, 6, 12, device=self.device) + self.assertEqual(m(x).dtype, torch.float32) + + +@skipUnless(HAS_NVSUBQ and HAS_CUDA, "Requires nvsubquadratic and CUDA") +class TestHyenaMixerDeterminism(unittest.TestCase): + device = "cuda" + + def test_same_seed_same_output(self): + torch.manual_seed(0) + m1 = HyenaMixer(dim=12, spatial_dims=3).to(self.device) + torch.manual_seed(0) + m2 = HyenaMixer(dim=12, spatial_dims=3).to(self.device) + x = torch.randn(1, 6, 6, 6, 12, device=self.device) + with torch.no_grad(): + y1, y2 = m1(x), m2(x) + torch.testing.assert_close(y1, y2, atol=0, rtol=0) + + +# --------------------------------------------------------------------------- +# HyenaTransformerBlock — full residual forward path +# --------------------------------------------------------------------------- + + +@skipUnless(HAS_NVSUBQ and HAS_CUDA, "Requires nvsubquadratic and CUDA") +class TestHyenaTransformerBlock(unittest.TestCase): + device = "cuda" + + def test_3d_forward_shape(self): + blk = HyenaTransformerBlock(dim=12, spatial_dims=3).to(self.device) + x = torch.randn(2, 6, 6, 6, 12, device=self.device) + self.assertEqual(blk(x).shape, x.shape) + + def test_2d_forward_shape(self): + blk = HyenaTransformerBlock(dim=8, spatial_dims=2).to(self.device) + x = torch.randn(2, 16, 16, 8, device=self.device) + self.assertEqual(blk(x).shape, x.shape) + + def test_grad_flow_through_block(self): + blk = HyenaTransformerBlock(dim=12, spatial_dims=3).to(self.device) + x = torch.randn(2, 6, 6, 6, 12, device=self.device) + blk(x).sum().backward() + self.assertIsNotNone(blk.mixer.qkv_proj.weight.grad) + self.assertIsNotNone(blk.mixer.out_proj.weight.grad) + mlp_params_with_grad = [p for p in blk.mlp.parameters() if p.grad is not None] + self.assertGreater(len(mlp_params_with_grad), 0) + + def 
test_mask_matrix_accepted_and_ignored(self): + """``mask_matrix`` is accepted (signature parity with Swin) but ignored.""" + blk = HyenaTransformerBlock(dim=12, spatial_dims=3).to(self.device) + x = torch.randn(2, 6, 6, 6, 12, device=self.device) + with torch.no_grad(): + y1 = blk(x) + y2 = blk(x, mask_matrix=torch.ones(1, device=self.device)) + torch.testing.assert_close(y1, y2) + + +@skipUnless(HAS_NVSUBQ and HAS_CUDA, "Requires nvsubquadratic and CUDA") +class TestHyenaMixerFFTShortConv(unittest.TestCase): + """The use_fft_short_conv=True path swaps Conv3d for DepthwiseFFTConv3d.""" + + device = "cuda" + + def test_3d_constructs_and_runs(self): + m = HyenaMixer(dim=12, spatial_dims=3, use_fft_short_conv=True).to(self.device) + x = torch.randn(2, 8, 8, 8, 12, device=self.device) + self.assertEqual(m(x).shape, x.shape) + + def test_3d_with_short_conv_chunks(self): + m = HyenaMixer(dim=12, spatial_dims=3, use_fft_short_conv=True, short_conv_fft_chunk_size=4).to(self.device) + x = torch.randn(2, 8, 8, 8, 12, device=self.device) + self.assertEqual(m(x).shape, x.shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/networks/nets/test_hyena_nd_unetr.py b/tests/networks/nets/test_hyena_nd_unetr.py new file mode 100644 index 0000000000..4fdb7356f1 --- /dev/null +++ b/tests/networks/nets/test_hyena_nd_unetr.py @@ -0,0 +1,137 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest +from unittest import skipUnless + +import torch +from parameterized import parameterized + +from monai.networks.blocks.hyena import HyenaTransformerBlock, is_nvsubquadratic_available +from monai.networks.nets.hyena_nd_unetr import PAPER_VARIANTS, HyenaNDUNETR +from monai.networks.nets.swin_unetr import SwinTransformerBlock, SwinUNETR +from tests.test_utils import skip_if_no_cuda + +HAS_NVSUBQ = is_nvsubquadratic_available() + + +PAPER_VARIANT_CASES = [ + ("HHHH", (True, True, True, True)), + ("HAHA", (True, False, True, False)), + ("HHAA", (True, True, False, False)), +] + + +def _block_type_at_stage(model, stage_idx): + layer_attr = ["layers1", "layers2", "layers3", "layers4"][stage_idx] + return type(getattr(model.swinViT, layer_attr)[0].blocks[0]) + + +class TestHyenaNDUNETRConstructorContract(unittest.TestCase): + """``HyenaNDUNETR.__init__`` enforces an explicit, non-empty ``hyena_stages``.""" + + @skipUnless(HAS_NVSUBQ, "Requires nvsubquadratic") + def test_explicit_stages_required(self): + with self.assertRaisesRegex(ValueError, "requires `hyena_stages`"): + HyenaNDUNETR(in_channels=1, out_channels=14, feature_size=12, hyena_stages=None) + + @skipUnless(HAS_NVSUBQ, "Requires nvsubquadratic") + def test_wrong_length_stages_rejected(self): + with self.assertRaisesRegex(ValueError, "length 4"): + HyenaNDUNETR(in_channels=1, out_channels=14, feature_size=12, hyena_stages=(True, True)) + + @skipUnless(HAS_NVSUBQ, "Requires nvsubquadratic") + def test_all_false_stages_rejected(self): + with self.assertRaisesRegex(ValueError, "at least one stage"): + HyenaNDUNETR(in_channels=1, out_channels=14, feature_size=12, hyena_stages=(False, False, False, False)) + + @skipUnless(HAS_NVSUBQ, "Requires nvsubquadratic") + def test_use_hyena_kwarg_rejected(self): + """The subclass forces use_hyena=True; caller may not override via kwargs.""" + with self.assertRaisesRegex(TypeError, "use_hyena"): + HyenaNDUNETR( + 
in_channels=1, out_channels=14, feature_size=12, hyena_stages=(True, True, False, False), use_hyena=True + ) + + @skipUnless(HAS_NVSUBQ, "Requires nvsubquadratic") + def test_subclass_of_swin_unetr(self): + m = HyenaNDUNETR(in_channels=1, out_channels=14, feature_size=12, hyena_stages=(True, True, False, False)) + self.assertIsInstance(m, SwinUNETR) + # The forced kwargs land on the instance via SwinUNETR.__init__. + self.assertTrue(m.use_hyena) + self.assertEqual(m.hyena_stages, (True, True, False, False)) + + +class TestHyenaNDUNETRFromPaperVariant(unittest.TestCase): + """``get_variant`` maps {HHHH, HAHA, HHAA} to the correct stage pattern.""" + + @parameterized.expand(PAPER_VARIANT_CASES) + @skipUnless(HAS_NVSUBQ, "Requires nvsubquadratic") + def test_returns_expected_stages(self, name, expected_stages): + m = HyenaNDUNETR.get_variant(name, in_channels=1, out_channels=14, feature_size=12) + self.assertEqual(m.hyena_stages, expected_stages) + for stage_idx, want_hyena in enumerate(expected_stages): + block_type = _block_type_at_stage(m, stage_idx) + if want_hyena: + self.assertIs(block_type, HyenaTransformerBlock) + else: + self.assertIs(block_type, SwinTransformerBlock) + + @skipUnless(HAS_NVSUBQ, "Requires nvsubquadratic") + def test_case_insensitive(self): + m_upper = HyenaNDUNETR.get_variant("HHAA", in_channels=1, out_channels=14, feature_size=12) + m_lower = HyenaNDUNETR.get_variant("hhaa", in_channels=1, out_channels=14, feature_size=12) + self.assertEqual(m_upper.hyena_stages, m_lower.hyena_stages) + + @skipUnless(HAS_NVSUBQ, "Requires nvsubquadratic") + def test_aaaa_rejected(self): + """AAAA is plain SwinUNETR and intentionally not exposed via this constructor.""" + with self.assertRaisesRegex(ValueError, "Unknown paper variant"): + HyenaNDUNETR.get_variant("AAAA", in_channels=1, out_channels=14, feature_size=12) + + @skipUnless(HAS_NVSUBQ, "Requires nvsubquadratic") + def test_unknown_variant_rejected(self): + with self.assertRaisesRegex(ValueError, 
"Unknown paper variant"): + HyenaNDUNETR.get_variant("HAAA", in_channels=1, out_channels=14, feature_size=12) + + @skipUnless(HAS_NVSUBQ, "Requires nvsubquadratic") + def test_redundant_hyena_stages_kwarg_rejected(self): + with self.assertRaisesRegex(ValueError, "do not also pass hyena_stages"): + HyenaNDUNETR.get_variant( + "HHAA", in_channels=1, out_channels=14, feature_size=12, hyena_stages=(True, False, True, False) + ) + + @skipUnless(HAS_NVSUBQ, "Requires nvsubquadratic") + def test_paper_variants_table_matches_constants(self): + """Guard against the table in PAPER_VARIANTS drifting.""" + self.assertEqual(PAPER_VARIANTS["HHHH"], (True, True, True, True)) + self.assertEqual(PAPER_VARIANTS["HAHA"], (True, False, True, False)) + self.assertEqual(PAPER_VARIANTS["HHAA"], (True, True, False, False)) + + +class TestHyenaNDUNETRForward(unittest.TestCase): + """End-to-end forward over the three paper variants. CUDA required.""" + + @parameterized.expand(PAPER_VARIANT_CASES) + @skipUnless(HAS_NVSUBQ, "Requires nvsubquadratic") + @skip_if_no_cuda + def test_forward_shape(self, name, _stages): + m = HyenaNDUNETR.get_variant(name, in_channels=1, out_channels=14, feature_size=12).cuda().eval() + x = torch.randn(1, 1, 64, 64, 64, device="cuda") + with torch.no_grad(): + out = m(x) + self.assertEqual(out.shape, (1, 14, 64, 64, 64)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/networks/nets/test_swin_unetr.py b/tests/networks/nets/test_swin_unetr.py index ba94aab4f9..80b627f194 100644 --- a/tests/networks/nets/test_swin_unetr.py +++ b/tests/networks/nets/test_swin_unetr.py @@ -21,7 +21,15 @@ from monai.apps import download_url from monai.networks import eval_mode -from monai.networks.nets.swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR, filter_swinunetr +from monai.networks.blocks.hyena import HyenaTransformerBlock, is_nvsubquadratic_available +from monai.networks.nets.swin_unetr import ( + PatchMerging, + PatchMergingV2, + SwinTransformer, + 
SwinTransformerBlock, + SwinUNETR, + filter_swinunetr, +) from monai.networks.utils import copy_model_state from monai.utils import optional_import from tests.test_utils import ( @@ -34,6 +42,8 @@ ) einops, has_einops = optional_import("einops") +HAS_NVSUBQ = is_nvsubquadratic_available() +HAS_CUDA = torch.cuda.is_available() test_merging_mode = ["mergingv2", "merging", PatchMerging, PatchMergingV2] checkpoint_vals = [True, False] @@ -126,5 +136,202 @@ def test_filter_swinunetr(self, input_param, key, value): self.assertTrue(len(loaded) == 157 and len(not_loaded) == 2) +# Backward-compat reference for SwinUNETR(use_hyena=False), feature_size=12, img_size=64^3, +# seeds (model=0, input=1), CPU. Captured before the HyenaND port; the default code path must +# keep reproducing this within tolerance. Tolerance-based (not a byte hash) so it tolerates +# benign cross-platform float drift while still catching a real change to the non-Hyena path. +HYENA_BACKCOMPAT_REF = torch.tensor( + [ + -0.069162, + -0.209673, + 0.543457, + -0.111868, + 0.474825, + 0.031108, + 0.191482, + -0.167401, + 0.091668, + 0.272223, + -0.084950, + -0.042126, + ] +) + + +def _build_hyena_unetr(use_hyena=False, hyena_stages=None, feature_size=12, out_channels=14): + return SwinUNETR( + in_channels=1, + out_channels=out_channels, + feature_size=feature_size, + use_hyena=use_hyena, + hyena_stages=hyena_stages, + ) + + +def _block_type_at_stage(model, stage_idx): + layer_attr = ["layers1", "layers2", "layers3", "layers4"][stage_idx] + return type(getattr(model.swinViT, layer_attr)[0].blocks[0]) + + +HYENA_VARIANT_CASES = [ + ("AAAA", False, None), + ("HHHH", True, None), + ("HAHA", True, (True, False, True, False)), + ("HHAA", True, (True, True, False, False)), +] + + +class TestSwinUNETRHyenaBackCompat(unittest.TestCase): + """The non-Hyena code path must keep reproducing its pre-port output (within tolerance).""" + + @skipUnless(has_einops, "Requires einops") + def test_default_path_unchanged(self): 
+ """SwinUNETR with no hyena kwargs reproduces the pre-port reference output. + + Runs on CPU so it executes in environments without a GPU and is stable across + platforms; ``assert_allclose`` tolerates benign float drift while still flagging a real + change to the default (non-Hyena) code path. + """ + torch.manual_seed(0) + net = SwinUNETR(in_channels=1, out_channels=14, feature_size=12).eval() + torch.manual_seed(1) + x = torch.randn(1, 1, 64, 64, 64) + with torch.no_grad(): + out = net(x) + self.assertEqual(out.shape, (1, 14, 64, 64, 64)) + assert_allclose( + out.flatten()[: HYENA_BACKCOMPAT_REF.numel()], HYENA_BACKCOMPAT_REF, atol=1e-4, rtol=1e-4, type_test=False + ) + + +class TestSwinUNETRHyenaStages(unittest.TestCase): + """``hyena_stages`` must place :class:`HyenaTransformerBlock` at flagged stages and + :class:`SwinTransformerBlock` everywhere else. Construction-only; no CUDA required.""" + + @skipUnless(HAS_NVSUBQ, "Requires nvsubquadratic") + def test_haha_pattern(self): + m = _build_hyena_unetr(use_hyena=True, hyena_stages=(True, False, True, False)) + self.assertIs(_block_type_at_stage(m, 0), HyenaTransformerBlock) + self.assertIs(_block_type_at_stage(m, 1), SwinTransformerBlock) + self.assertIs(_block_type_at_stage(m, 2), HyenaTransformerBlock) + self.assertIs(_block_type_at_stage(m, 3), SwinTransformerBlock) + + @skipUnless(HAS_NVSUBQ, "Requires nvsubquadratic") + def test_hhaa_pattern(self): + m = _build_hyena_unetr(use_hyena=True, hyena_stages=(True, True, False, False)) + self.assertIs(_block_type_at_stage(m, 0), HyenaTransformerBlock) + self.assertIs(_block_type_at_stage(m, 1), HyenaTransformerBlock) + self.assertIs(_block_type_at_stage(m, 2), SwinTransformerBlock) + self.assertIs(_block_type_at_stage(m, 3), SwinTransformerBlock) + + def test_aaaa_pattern_default(self): + m = _build_hyena_unetr(use_hyena=False) + for i in range(4): + self.assertIs(_block_type_at_stage(m, i), SwinTransformerBlock) + + @skipUnless(HAS_NVSUBQ, "Requires
nvsubquadratic") + def test_hhhh_pattern_default(self): + m = _build_hyena_unetr(use_hyena=True) + for i in range(4): + self.assertIs(_block_type_at_stage(m, i), HyenaTransformerBlock) + + @skipUnless(HAS_NVSUBQ, "Requires nvsubquadratic") + def test_wrong_length_hyena_stages_raises(self): + with self.assertRaisesRegex(ValueError, "hyena_stages must have length"): + _build_hyena_unetr(use_hyena=True, hyena_stages=(True, True)) + + +class TestSwinUNETRHyenaForward(unittest.TestCase): + """Forward shape across the four paper variants. CUDA required.""" + + @parameterized.expand(HYENA_VARIANT_CASES) + @skipUnless(HAS_NVSUBQ, "Requires nvsubquadratic") + @skip_if_no_cuda + def test_forward_shape(self, _name, use_hyena, hyena_stages): + m = _build_hyena_unetr(use_hyena=use_hyena, hyena_stages=hyena_stages).cuda() + x = torch.randn(1, 1, 64, 64, 64, device="cuda") + with torch.no_grad(): + out = m(x) + self.assertEqual(out.shape, (1, 14, 64, 64, 64)) + + +class TestSwinUNETRHyenaGradient(unittest.TestCase): + """Backward through the HHAA variant must produce grads on at least 90 percent of params.""" + + @skipUnless(HAS_NVSUBQ, "Requires nvsubquadratic") + @skip_if_no_cuda + def test_hhaa_backward(self): + m = _build_hyena_unetr(use_hyena=True, hyena_stages=(True, True, False, False)).cuda() + x = torch.randn(1, 1, 64, 64, 64, device="cuda") + m(x).sum().backward() + total = list(m.parameters()) + with_grad = [p for p in total if p.grad is not None] + coverage = len(with_grad) / len(total) + self.assertGreater(coverage, 0.9, f"only {coverage:.1%} of params received gradients") + + +class TestSwinTransformerRoPEDivisibility(unittest.TestCase): + """3D Hyena requires embed_dim * 2^layer % 6 == 0; 2D requires % 4.""" + + @skipUnless(HAS_NVSUBQ, "Requires nvsubquadratic") + def test_3d_rejects_non_divisible_embed_dim(self): + with self.assertRaisesRegex(ValueError, "divisible by 6"): + SwinTransformer( + in_chans=1, + embed_dim=14, + window_size=(2, 2, 2), + patch_size=(2, 
2, 2), + depths=(2, 2, 2, 2), + num_heads=(3, 6, 12, 24), + spatial_dims=3, + use_hyena=True, + ) + + @skipUnless(HAS_NVSUBQ, "Requires nvsubquadratic") + def test_2d_rejects_non_divisible_embed_dim(self): + with self.assertRaisesRegex(ValueError, "divisible by 4"): + SwinTransformer( + in_chans=1, + embed_dim=14, + window_size=(2, 2), + patch_size=(2, 2), + depths=(2, 2, 2, 2), + num_heads=(3, 6, 12, 24), + spatial_dims=2, + use_hyena=True, + ) + + @skipUnless(HAS_NVSUBQ, "Requires nvsubquadratic") + def test_per_stage_skips_check_for_attention_stages(self): + """Per-stage False suppresses the check for that stage; remaining Hyena stages still fire.""" + with self.assertRaisesRegex(ValueError, "divisible by 6"): + SwinTransformer( + in_chans=1, + embed_dim=14, + window_size=(2, 2, 2), + patch_size=(2, 2, 2), + depths=(2, 2, 2, 2), + num_heads=(3, 6, 12, 24), + spatial_dims=3, + use_hyena=True, + hyena_stages=(False, True, False, False), + ) + + +class TestSwinUNETRHyenaSlidingWindow(unittest.TestCase): + """The production inference path: sliding-window inference over HHAA must succeed.""" + + @skipUnless(HAS_NVSUBQ, "Requires nvsubquadratic") + @skip_if_no_cuda + def test_swi_hhaa(self): + from monai.inferers import sliding_window_inference + + m = _build_hyena_unetr(use_hyena=True, hyena_stages=(True, True, False, False)).cuda().eval() + x = torch.randn(1, 1, 96, 96, 96, device="cuda") + with torch.no_grad(): + out = sliding_window_inference(inputs=x, roi_size=(64, 64, 64), sw_batch_size=2, predictor=m, overlap=0.25) + self.assertEqual(out.shape, (1, 14, 96, 96, 96)) + + if __name__ == "__main__": unittest.main()