Skip to content

Add Graph-safe NVFP4 CUTLASS Group GEMM to unified entry points#3175

Open
cael-ling wants to merge 5 commits into
NVIDIA:mainfrom
cael-ling:optimize/group-gemm-grouped-tensor
Open

Add Graph-safe NVFP4 CUTLASS Group GEMM to unified entry points#3175
cael-ling wants to merge 5 commits into
NVIDIA:mainfrom
cael-ling:optimize/group-gemm-grouped-tensor

Conversation

@cael-ling

@cael-ling cael-ling commented Jul 4, 2026

Copy link
Copy Markdown
Contributor

Description

Follow-up to #3134.

Review feedback (from @vthumbe1503) was that TE already has grouped-GEMM entry points (cuBLAS single-launch, multi-stream cuBLAS, cuDSL single-launch, and now CUTLASS), and that the CUTLASS kernel should instead live behind the single grouped-tensor API. This PR migrates the kernel from #3134 accordingly:

  • CUTLASS is now an env-selectable backend inside general_grouped_gemm_for_grouped_tensor (the graph-safe grouped-tensor path), still gated by NVTE_NVFP4_CUTLASS_GROUPED_GEMM. NVTE_GROUPED_LINEAR_USE_FUSED_GROUPED_GEMM=1 routes GroupedLinear onto that path.
  • It consumes the same device-side setup metadata as the cuBLAS single-launch path, so the CUTLASS path is now CUDA-graph safe.
  • Benchmarked head-to-head against the cuBLAS single-launch path (not multi-stream now), both eager per-GEMM and CUDA-graphed whole-step.

Default behavior is unchanged (backend is opt-in). #3134 is temporarily retained for comparison purposes.

Performance

SM100 (Blackwell), NVFP4 default recipe. cuBLAS single-launch vs CUTLASS single-launch, both through general_grouped_gemm_for_grouped_tensor. Times in ms/iter (lower is better), iters=30. Reproduce with:

python benchmarks/linear/benchmark_grouped_linear.py --compare-backends # eager, per-GEMM python benchmarks/linear/benchmark_grouped_linear.py --compare-backends --graph # CUDA-graphed whole step

Eager, isolated per-GEMM (cuBLAS / CUTLASS / speedup)

experts tokens hidden out dist fprop dgrad wgrad
8 8192 4096 4096 balanced 0.824/0.793/1.04x 0.502/0.481/1.04x 0.866/0.530/1.63x
8 8192 4096 4096 imbalanced 0.853/0.826/1.03x 0.495/0.482/1.03x 0.865/0.532/1.63x
8 16384 7168 2048 balanced 0.836/0.813/1.03x 0.499/0.483/1.03x 0.796/0.508/1.57x
8 16384 7168 2048 imbalanced 0.860/0.800/1.07x 0.511/0.487/1.05x 0.790/0.509/1.55x
16 16384 4096 4096 balanced 1.271/1.215/1.05x 0.567/0.544/1.04x 1.427/0.806/1.77x
16 16384 4096 4096 imbalanced 1.265/1.233/1.03x 0.555/0.526/1.05x 1.422/0.814/1.75x
32 16384 2048 2048 balanced 2.056/2.014/1.02x 0.670/0.651/1.03x 0.909/0.556/1.63x
32 16384 2048 2048 imbalanced 2.073/2.038/1.02x 0.664/0.644/1.03x 0.916/0.560/1.63x
8 65536 7168 2048 balanced 1.065/1.118/0.95x 0.702/0.824/0.85x 1.205/0.799/1.51x
8 65536 7168 2048 imbalanced 1.056/1.124/0.94x 0.704/0.822/0.86x 1.194/0.823/1.45x

CUDA-graphed whole-step fwd+bwd (cuBLAS / CUTLASS / speedup)

experts tokens hidden out dist fwd+bwd
8 8192 4096 4096 balanced 1.218/0.928/1.31x
8 8192 4096 4096 imbalanced 1.221/0.920/1.33x
8 16384 7168 2048 balanced 1.284/1.054/1.22x
8 16384 7168 2048 imbalanced 1.289/1.058/1.22x
16 16384 4096 4096 balanced 2.313/1.737/1.33x
16 16384 4096 4096 imbalanced 2.320/1.732/1.34x
32 16384 2048 2048 balanced 1.472/1.155/1.27x
32 16384 2048 2048 imbalanced 1.494/1.146/1.30x
8 65536 7168 2048 balanced 2.689/2.485/1.08x
8 65536 7168 2048 imbalanced 2.688/2.489/1.08x
Summary: wgrad is the main win (~1.45–1.77x). fprop/dgrad are ~parity and
slightly negative at very large M (65536). End-to-end CUDA-graphed step is
~1.08–1.34x over single launch cuBLASLt backend

Type of change

  • New feature (non-breaking change which adds functionality)
  • Code refactoring

Changes

Common (CUDA/C++)

  • Route the per-tensor NVFP4 CUTLASS kernel through
    general_grouped_gemm_for_grouped_tensor / execute_grouped_gemm, consuming
    the device-side GroupedGemmSetupWorkspace arrays (graph-safe), selected by
    NVTE_NVFP4_CUTLASS_GROUPED_GEMM; fall back to cuBLAS when unsupported.
  • Replace the host-side vector launcher run_grouped_per_tensor_gemm (and its
    persistent host buffer) with run_grouped_per_tensor_gemm_grouped_tensor.
  • Remove the old multi-stream CUTLASS dispatch branch and the two bench-only
    entry points.
    PyTorch
  • Remove the two bench-only bindings.
  • tests/pytorch/test_grouped_linear.py: grouped-tensor backend parity
    (test_nvfp4_grouped_tensor_cutlass_matches_cublas) + CUDA-graph safety
    (test_nvfp4_grouped_tensor_cutlass_cuda_graph_safe); remove the dead
    multi-stream CUTLASS tests.
  • benchmarks/linear/benchmark_grouped_linear.py: --compare-backends
    (eager per-GEMM) and --compare-backends --graph (CUDA-graphed whole-step).

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes (full
    tests/pytorch/test_grouped_linear.py passes on SM100)

Replace the per-expert multi-stream cuBLASLt loop in the per-tensor NVFP4
grouped path with one CUTLASS ptr-array grouped launch on SM100 (Blackwell).

Common:
- Add nvfp4_cutlass_grouped_gemm.{cuh,cu}: a single-launch per-tensor NVFP4
  grouped kernel. Covers BF16 output (fprop/dgrad, overwrite), FP32 output
  (wgrad, with optional in-place accumulate for Megatron wgrad fusion), and
  optional fused per-group bias (fprop). Per-tensor scaling collapses the
  second-level scale to one fp32 alpha per group, applied via the epilogue's
  per-group alpha_ptr_array, Arch-gated on CUTLASS_ARCH_MMA_SM100_SUPPORTED.
- Wire it into nvte_multi_tensor_gemm behind the opt-in env
  NVTE_NVFP4_CUTLASS_GROUPED_GEMM. M/N/K must be %128; empty (0-token)
  experts schedule 0 tiles and no longer force a multi-stream fallback.
  Anything unsupported falls through to the existing cuBLAS path, so default
  behavior is unchanged.
- Add bench-only entry points nvte_nvfp4_grouped_per_tensor_compute_alpha /
  nvte_nvfp4_grouped_per_tensor_gemm so a benchmark can precompute alpha
  outside the timed region and time only the grouped GEMM launch.

PyTorch:
- Expose the two bench-only bindings (tex.nvfp4_grouped_per_tensor_compute_alpha
  and tex.nvfp4_grouped_per_tensor_gemm). Not used by the production dispatch.
- Extend test_grouped_linear.py with NVFP4 cutlass-vs-multistream parity tests:
  GEMM-level (uniform + uneven 128-aligned splits), empty groups, and
  end-to-end GroupedLinear fwd+bwd (bias, fuse_wgrad_accumulation).
- Add a GEMM-level cutlass-vs-multistream comparison to
  benchmark_grouped_linear.py (--compare-nvfp4-grouped-gemm): a DISPATCH row
  (both backends via the shared dispatch) and a fair PURE row (operands
  pre-swizzled, alpha precomputed; times only the GEMM).

Signed-off-by: Cael Ling <caell@nvidia.com>
…uped-tensor path

Follow-up to NVIDIA#3134, which added the single-launch CUTLASS ptr-array grouped
kernel for per-tensor NVFP4 but wired it into the per-expert multi-stream
nvte_multi_tensor_gemm path. Per review feedback about grouped-GEMM entry-point
fragmentation, move the CUTLASS kernel behind
general_grouped_gemm_for_grouped_tensor as an backend and
benchmark it head-to-head against the cuBLAS single-launch path (not
multi-stream). Consuming the device-side grouped-tensor setup metadata also
makes the CUTLASS path CUDA-graph safe.

Common:
- Route the CUTLASS kernel through the grouped-tensor GEMM
  (general_grouped_gemm_for_grouped_tensor / execute_grouped_gemm). It consumes
  the same device-side GroupedGemmSetupWorkspace arrays as the cuBLAS
  single-launch path, inheriting its graph-safety; selected by
  NVTE_NVFP4_CUTLASS_GROUPED_GEMM, falling back to cuBLAS when unsupported.
- Replace the host-side vector launcher (run_grouped_per_tensor_gemm) and its
  persistent host buffer with the device-metadata entry point
  run_grouped_per_tensor_gemm_grouped_tensor.
- Remove the old multi-stream CUTLASS dispatch branch and the two bench-only
  entry points (nvte_nvfp4_grouped_per_tensor_compute_alpha /
  nvte_nvfp4_grouped_per_tensor_gemm) and their declarations.

PyTorch:
- Drop the two bench-only bindings.
- test_grouped_linear.py: replace the now-dead multi-stream CUTLASS parity
  tests with grouped-tensor backend tests --
  test_nvfp4_grouped_tensor_cutlass_matches_cublas (fprop/dgrad/wgrad parity
  across balanced/imbalanced splits and fuse_wgrad_accumulation) and
  test_nvfp4_grouped_tensor_cutlass_cuda_graph_safe.
- benchmark_grouped_linear.py: replace --compare-nvfp4-grouped-gemm with
  --compare-backends (eager per-GEMM) and --compare-backends --graph
  (CUDA-graphed whole-step) CUTLASS-vs-cuBLAS single-launch comparison.

Signed-off-by: Cael Ling <caell@nvidia.com>
@cael-ling cael-ling requested a review from ksivaman as a code owner July 4, 2026 09:23
@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label Jul 4, 2026
@greptile-apps

greptile-apps Bot commented Jul 4, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR migrates the NVFP4 CUTLASS single-launch grouped GEMM kernel from standalone bench-only entry points into the existing graph-safe grouped-tensor path (general_grouped_gemm_for_grouped_tensor), gated by NVTE_NVFP4_CUTLASS_GROUPED_GEMM. The persistent device buffer is now correctly keyed by device ID (using current_device()), addressing the multi-GPU concerns from the prior review, and all metadata is built on-device via a small setup kernel so the launch is CUDA-graph capturable with no host↔device sync.

  • New nvfp4_cutlass_grouped_gemm.cu: Implements a device-side metadata builder kernel (build_grouped_metadata_kernel) and a persistent_buffer cache (per-device, two slots) that keeps CUTLASS scratch and workspace alive across launches; run_impl_device<ElementOutT> dispatches to either a BF16-output (fprop/dgrad) or FP32-output (wgrad, with per-group beta accumulation) kernel.
  • cublaslt_grouped_gemm.cu: Adds maybe_run_nvfp4_cutlass_grouped as an inline gate at all three grouped-GEMM dispatch sites; falls back to cuBLAS when the env var is unset, the SM is not 100, or the output type doesn't qualify.
  • Tests / benchmark: Two new parity tests (backend equivalence and CUDA-graph safety) and a --compare-backends flag in the benchmark that sweeps MoE shapes across cuBLAS vs CUTLASS.

Confidence Score: 5/5

Safe to merge. The CUTLASS backend is strictly opt-in, default behaviour is unchanged, and the new code shares the same device-side setup workspace as the existing cuBLAS single-launch path — inheriting its graph safety by construction.

The migration is well-scoped: the new CUTLASS path is off by default and reachable only through an explicit env var + SM100 + NVFP4 guard. The previous review's multi-GPU concerns (hardcoded device 0, device-unaware persistent buffer) have both been addressed — device ID is now resolved via current_device() and the buffer is keyed per device. The two new tests (backend parity and CUDA-graph safety) exercise the critical paths on real hardware. The only items flagged are defensive coding gaps that do not affect correctness under the current call-site invariants.

transformer_engine/common/gemm/cublaslt_grouped_gemm.cu (the maybe_run_nvfp4_cutlass_grouped guard) and transformer_engine/common/gemm/nvfp4_cutlass_grouped_gemm.cu (the CUTLASS initialize contract during graph capture) are worth a second look.

Important Files Changed

Filename Overview
transformer_engine/common/gemm/nvfp4_cutlass_grouped_gemm.cu New CUTLASS NVFP4 per-tensor grouped GEMM implementation; device-keyed persistent buffer and on-device metadata builder kernel are correct; minor defensive gap in the B operand scaling-mode guard in the caller.
transformer_engine/common/gemm/cublaslt_grouped_gemm.cu Adds maybe_run_nvfp4_cutlass_grouped at three dispatch sites; SM and dtype guards are correct; the A_sel-only scaling-mode check is a minor defensive gap since B_sel is not independently validated in this function.
transformer_engine/common/gemm/nvfp4_cutlass_grouped_gemm.cuh Clean header declaring the single public entry point with thorough documentation of the cuBLAS→CUTLASS A/B swap convention and accumulate semantics.
tests/pytorch/test_grouped_linear.py Adds parity (cuBLAS vs CUTLASS) and CUDA-graph-safety tests for the NVFP4 grouped-tensor path; tolerances are intentionally loose to accommodate reduction-order differences.
benchmarks/linear/benchmark_grouped_linear.py Adds --compare-backends (eager per-GEMM) and --compare-backends --graph (CUDA-graphed fwd+bwd) sweep modes; FP8 state is correctly reset between backends; no issues found.
transformer_engine/common/CMakeLists.txt Correctly registers nvfp4_cutlass_grouped_gemm.cu in both the arch-specific source list and the CUTLASS_KERNEL_SOURCES debug-build guard list.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["nvte_grouped_gemm / nvte_grouped_gemm_with_discrete_inputA\n/ nvte_grouped_gemm_with_discrete_out"] --> B["launch_grouped_gemm_setup\n(populates GroupedGemmSetupWorkspace on device)"]
    B --> C["maybe_run_nvfp4_cutlass_grouped"]
    C --> D{NVTE_NVFP4_CUTLASS_GROUPED_GEMM == 1?}
    D -- No --> G["execute_grouped_gemm\n(cuBLAS single-launch)"]
    D -- Yes --> E{A_sel == NVFP4\nSM100\nBF16 or FP32 out?}
    E -- No --> G
    E -- Yes --> F["run_grouped_per_tensor_gemm_grouped_tensor\n(CUTLASS path)"]
    F --> H["persistent_buffer\n(per-device scratch + workspace)"]
    H --> I["build_grouped_metadata_kernel\n(device: fills problems_d, strides, SFA/SFB layouts)"]
    I --> J{fp32_output?}
    J -- Yes --> K["run_impl_device<float>\n(wgrad / per-group beta)"]
    J -- No --> L["run_impl_device<bfloat16_t>\n(fprop / dgrad, overwrite)"]
    K --> M["Gemm::initialize + Gemm::run\n(CUTLASS single-launch, CUDA-graph safe)"]
    L --> M
    G --> N["return to caller"]
    M --> N
Loading
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
flowchart TD
    A["nvte_grouped_gemm / nvte_grouped_gemm_with_discrete_inputA\n/ nvte_grouped_gemm_with_discrete_out"] --> B["launch_grouped_gemm_setup\n(populates GroupedGemmSetupWorkspace on device)"]
    B --> C["maybe_run_nvfp4_cutlass_grouped"]
    C --> D{NVTE_NVFP4_CUTLASS_GROUPED_GEMM == 1?}
    D -- No --> G["execute_grouped_gemm\n(cuBLAS single-launch)"]
    D -- Yes --> E{A_sel == NVFP4\nSM100\nBF16 or FP32 out?}
    E -- No --> G
    E -- Yes --> F["run_grouped_per_tensor_gemm_grouped_tensor\n(CUTLASS path)"]
    F --> H["persistent_buffer\n(per-device scratch + workspace)"]
    H --> I["build_grouped_metadata_kernel\n(device: fills problems_d, strides, SFA/SFB layouts)"]
    I --> J{fp32_output?}
    J -- Yes --> K["run_impl_device<float>\n(wgrad / per-group beta)"]
    J -- No --> L["run_impl_device<bfloat16_t>\n(fprop / dgrad, overwrite)"]
    K --> M["Gemm::initialize + Gemm::run\n(CUTLASS single-launch, CUDA-graph safe)"]
    L --> M
    G --> N["return to caller"]
    M --> N
Loading

Reviews (2): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile

Comment thread transformer_engine/common/gemm/nvfp4_cutlass_grouped_gemm.cu Outdated
Comment thread transformer_engine/common/gemm/nvfp4_cutlass_grouped_gemm.cu
Comment thread tests/pytorch/test_grouped_linear.py Outdated
Comment thread tests/pytorch/test_grouped_linear.py Outdated
cael-ling and others added 3 commits July 4, 2026 02:51
…-grouped-tensor

Resolve grouped_linear.py: upstream independently routes NVFP4 (with RHT)
through the grouped-tensor path, which supersedes this branch's Python
routing change. Take upstream's _is_grouped_tensor_path_supported; the CUTLASS
backend hooks in at the C++ dispatch and is unaffected.

Signed-off-by: Cael Ling <caell@nvidia.com>
…lerances

Signed-off-by: Cael Ling <caell@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant