Add Graph-safe NVFP4 CUTLASS Group GEMM to unified entry points #3175
cael-ling wants to merge 5 commits into
Replace the per-expert multi-stream cuBLASLt loop in the per-tensor NVFP4
grouped path with one CUTLASS ptr-array grouped launch on SM100 (Blackwell).
Common:
- Add nvfp4_cutlass_grouped_gemm.{cuh,cu}: a single-launch per-tensor NVFP4
grouped kernel. Covers BF16 output (fprop/dgrad, overwrite), FP32 output
(wgrad, with optional in-place accumulate for Megatron wgrad fusion), and
optional fused per-group bias (fprop). Per-tensor scaling collapses the
second-level scale to one fp32 alpha per group, applied via the epilogue's
per-group alpha_ptr_array. Arch-gated on CUTLASS_ARCH_MMA_SM100_SUPPORTED.
- Wire it into nvte_multi_tensor_gemm behind the opt-in env
NVTE_NVFP4_CUTLASS_GROUPED_GEMM. M, N, and K must each be a multiple of 128;
empty (0-token) experts schedule 0 tiles and no longer force a multi-stream fallback.
Anything unsupported falls through to the existing cuBLAS path, so default
behavior is unchanged.
- Add bench-only entry points nvte_nvfp4_grouped_per_tensor_compute_alpha /
nvte_nvfp4_grouped_per_tensor_gemm so a benchmark can precompute alpha
outside the timed region and time only the grouped GEMM launch.
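The semantics described in the Common items above can be modeled with a small NumPy reference. This is only an illustrative sketch, not the CUTLASS kernel: the name `grouped_gemm_reference` and its arguments are hypothetical, operands are plain float32 rather than NVFP4, and the alignment check simply mirrors the multiple-of-128 requirement stated above.

```python
import numpy as np

ALIGN = 128  # M, N and K must each be a multiple of 128 (per the text above)


def grouped_gemm_reference(a_list, b_list, alphas, accum_list=None, biases=None):
    """Hypothetical reference: D_g = alpha_g * (A_g @ B_g) [+ bias_g] [+ accum_g]."""
    results = []
    for g, (a, b) in enumerate(zip(a_list, b_list)):
        m, k = a.shape
        k_b, n = b.shape
        if k != k_b:
            raise ValueError(f"group {g}: inner dimensions differ ({k} vs {k_b})")
        # A 0-row group passes this check (0 % 128 == 0) and yields an empty output.
        if any(dim % ALIGN for dim in (m, n, k)):
            raise ValueError(f"group {g}: M/N/K must be multiples of {ALIGN}")
        # One fp32 alpha per group stands in for the collapsed second-level scale.
        d = np.float32(alphas[g]) * (a.astype(np.float32) @ b.astype(np.float32))
        if biases is not None:       # optional fused per-group bias (fprop)
            d = d + biases[g]
        if accum_list is not None:   # optional accumulate into an existing FP32 output (wgrad)
            d = accum_list[g] + d
        results.append(d)
    return results
```

An unsupported shape raises here, whereas the PR describes falling through to the existing cuBLAS path in that case.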
PyTorch:
- Expose the two bench-only bindings (tex.nvfp4_grouped_per_tensor_compute_alpha
and tex.nvfp4_grouped_per_tensor_gemm). Not used by the production dispatch.
- Extend test_grouped_linear.py with NVFP4 cutlass-vs-multistream parity tests:
GEMM-level (uniform + uneven 128-aligned splits), empty groups, and
end-to-end GroupedLinear fwd+bwd (bias, fuse_wgrad_accumulation).
- Add a GEMM-level cutlass-vs-multistream comparison to
benchmark_grouped_linear.py (--compare-nvfp4-grouped-gemm): a DISPATCH row
(both backends via the shared dispatch) and a fair PURE row (operands
pre-swizzled, alpha precomputed; times only the GEMM).
Signed-off-by: Cael Ling <caell@nvidia.com>
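The parity tests listed above exercise uniform and uneven 128-aligned splits, including empty groups. A minimal sketch of how such splits could be generated follows; `aligned_splits` is a hypothetical helper written for illustration and is not taken from test_grouped_linear.py.

```python
import random

ALIGN = 128  # every per-group token count must be a multiple of 128


def aligned_splits(total_tokens, num_groups, uniform=True, seed=0):
    """Hypothetical helper: split total_tokens into num_groups 128-aligned chunks."""
    if total_tokens % ALIGN:
        raise ValueError("total_tokens must be a multiple of 128")
    blocks = total_tokens // ALIGN
    if uniform:
        if blocks % num_groups:
            raise ValueError("cannot split evenly into 128-aligned groups")
        return [blocks // num_groups * ALIGN] * num_groups
    # Uneven case: hand out 128-token blocks at random; some groups may stay empty.
    rng = random.Random(seed)
    counts = [0] * num_groups
    for _ in range(blocks):
        counts[rng.randrange(num_groups)] += 1
    return [c * ALIGN for c in counts]
```

Because blocks are assigned at random in the uneven case, a group can receive zero tokens, which is the empty-expert situation the tests are described as covering.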
…uped-tensor path
Follow-up to NVIDIA#3134, which added the single-launch CUTLASS ptr-array
grouped kernel for per-tensor NVFP4 but wired it into the per-expert
multi-stream nvte_multi_tensor_gemm path. Per review feedback about
grouped-GEMM entry-point fragmentation, move the CUTLASS kernel behind
general_grouped_gemm_for_grouped_tensor as a backend and benchmark it
head-to-head against the cuBLAS single-launch path (not multi-stream).
Consuming the device-side grouped-tensor setup metadata also makes the
CUTLASS path CUDA-graph safe.
Common:
- Route the CUTLASS kernel through the grouped-tensor GEMM
(general_grouped_gemm_for_grouped_tensor / execute_grouped_gemm). It consumes
the same device-side GroupedGemmSetupWorkspace arrays as the cuBLAS
single-launch path, inheriting its graph-safety; selected by
NVTE_NVFP4_CUTLASS_GROUPED_GEMM, falling back to cuBLAS when unsupported.
- Replace the host-side vector launcher (run_grouped_per_tensor_gemm) and its
persistent host buffer with the device-metadata entry point
run_grouped_per_tensor_gemm_grouped_tensor.
- Remove the old multi-stream CUTLASS dispatch branch and the two bench-only
entry points (nvte_nvfp4_grouped_per_tensor_compute_alpha /
nvte_nvfp4_grouped_per_tensor_gemm) and their declarations.
PyTorch:
- Drop the two bench-only bindings.
- test_grouped_linear.py: replace the now-dead multi-stream CUTLASS parity
tests with grouped-tensor backend tests --
test_nvfp4_grouped_tensor_cutlass_matches_cublas (fprop/dgrad/wgrad parity
across balanced/imbalanced splits and fuse_wgrad_accumulation) and
test_nvfp4_grouped_tensor_cutlass_cuda_graph_safe.
- benchmark_grouped_linear.py: replace --compare-nvfp4-grouped-gemm with
--compare-backends (eager per-GEMM) and --compare-backends --graph
(CUDA-graphed whole-step) CUTLASS-vs-cuBLAS single-launch comparison.
Signed-off-by: Cael Ling <caell@nvidia.com>
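The graph-safety argument in this commit message rests on where per-step metadata lives. The toy model below, in plain Python with no CUDA involved, illustrates the idea under the assumption that a captured launch replays with whatever arguments it held at capture time: a host-side snapshot goes stale, while a buffer updated in place is read afresh. `CapturedLaunch` and `total_rows` are hypothetical names used only for this illustration.

```python
class CapturedLaunch:
    """Toy stand-in for a captured kernel launch: arguments are frozen at capture."""

    def __init__(self, fn, *args):
        self.fn = fn
        self.args = args  # the same objects are reused on every replay

    def replay(self):
        return self.fn(*self.args)


def total_rows(group_sizes):
    """Stands in for work whose shape depends on per-group metadata."""
    return sum(group_sizes)


device_sizes = [128, 256]            # models a device-side metadata buffer
host_snapshot = tuple(device_sizes)  # models host values baked into the launch

stale = CapturedLaunch(total_rows, host_snapshot)
live = CapturedLaunch(total_rows, device_sizes)

device_sizes[:] = [512, 0]           # a later step rewrites the buffer in place

stale_result = stale.replay()        # still reflects the capture-time split
live_result = live.replay()          # reflects the updated split
```

In this model only the launch that reads the in-place buffer observes the new split, which is the property the device-side GroupedGemmSetupWorkspace arrays are described as providing.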
Greptile Summary
This PR migrates the NVFP4 CUTLASS single-launch grouped GEMM kernel from standalone bench-only entry points into the existing graph-safe grouped-tensor path (
Confidence Score: 5/5
Safe to merge. The CUTLASS backend is strictly opt-in, default behaviour is unchanged, and the new code shares the same device-side setup workspace as the existing cuBLAS single-launch path, inheriting its graph safety by construction.
The migration is well-scoped: the new CUTLASS path is off by default and reachable only through an explicit env var + SM100 + NVFP4 guard. The previous review's multi-GPU concerns (hardcoded device 0, device-unaware persistent buffer) have both been addressed: device ID is now resolved via current_device() and the buffer is keyed per device. The two new tests (backend parity and CUDA-graph safety) exercise the critical paths on real hardware. The only items flagged are defensive coding gaps that do not affect correctness under the current call-site invariants.
transformer_engine/common/gemm/cublaslt_grouped_gemm.cu (the maybe_run_nvfp4_cutlass_grouped guard) and transformer_engine/common/gemm/nvfp4_cutlass_grouped_gemm.cu (the CUTLASS initialize contract during graph capture) are worth a second look.
Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["nvte_grouped_gemm / nvte_grouped_gemm_with_discrete_inputA\n/ nvte_grouped_gemm_with_discrete_out"] --> B["launch_grouped_gemm_setup\n(populates GroupedGemmSetupWorkspace on device)"]
B --> C["maybe_run_nvfp4_cutlass_grouped"]
C --> D{NVTE_NVFP4_CUTLASS_GROUPED_GEMM == 1?}
D -- No --> G["execute_grouped_gemm\n(cuBLAS single-launch)"]
D -- Yes --> E{A_sel == NVFP4\nSM100\nBF16 or FP32 out?}
E -- No --> G
E -- Yes --> F["run_grouped_per_tensor_gemm_grouped_tensor\n(CUTLASS path)"]
F --> H["persistent_buffer\n(per-device scratch + workspace)"]
H --> I["build_grouped_metadata_kernel\n(device: fills problems_d, strides, SFA/SFB layouts)"]
I --> J{fp32_output?}
J -- Yes --> K["run_impl_device<float>\n(wgrad / per-group beta)"]
J -- No --> L["run_impl_device<bfloat16_t>\n(fprop / dgrad, overwrite)"]
K --> M["Gemm::initialize + Gemm::run\n(CUTLASS single-launch, CUDA-graph safe)"]
L --> M
G --> N["return to caller"]
M --> N
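The decision nodes D and E in the flowchart can be sketched as a small selection function. This is a hedged illustration of the control flow only: `GemmRequest`, its fields, the string backend names, and the encoding of SM100 as the integer 100 are all hypothetical and do not reflect the actual C++ guard in maybe_run_nvfp4_cutlass_grouped.

```python
import os
from dataclasses import dataclass


@dataclass
class GemmRequest:
    """Hypothetical summary of the properties the flowchart checks."""
    a_is_nvfp4: bool
    sm_arch: int      # assumed encoding: 100 stands for SM100
    out_dtype: str    # "bf16" or "fp32" are the supported outputs


def select_backend(req, env=None):
    """Return "cutlass" only when opted in and supported, else "cublas"."""
    env = os.environ if env is None else env
    # Node D: the backend is opt-in through the environment variable.
    if env.get("NVTE_NVFP4_CUTLASS_GROUPED_GEMM", "0") != "1":
        return "cublas"
    # Node E: NVFP4 operand, SM100, and BF16 or FP32 output.
    supported = (
        req.a_is_nvfp4
        and req.sm_arch == 100
        and req.out_dtype in ("bf16", "fp32")
    )
    return "cutlass" if supported else "cublas"
```

Every unsupported combination lands on the cuBLAS single-launch branch, matching the "No" edges into node G.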
…-grouped-tensor
Resolve grouped_linear.py: upstream independently routes NVFP4 (with RHT) through the grouped-tensor path, which supersedes this branch's Python routing change. Take upstream's _is_grouped_tensor_path_supported; the CUTLASS backend hooks in at the C++ dispatch and is unaffected.
Signed-off-by: Cael Ling <caell@nvidia.com>
…lerances
Signed-off-by: Cael Ling <caell@nvidia.com>
for more information, see https://pre-commit.ci
Description
Follow-up to #3134.
Review feedback (from @vthumbe1503) was that TE already has several grouped-GEMM entry points (cuBLAS single-launch, multi-stream cuBLAS, cuDSL single-launch, and now CUTLASS), and that the CUTLASS kernel should instead live behind the single grouped-tensor API. This PR migrates the kernel from #3134 accordingly:
- The CUTLASS kernel now sits behind general_grouped_gemm_for_grouped_tensor (the graph-safe grouped-tensor path), still gated by NVTE_NVFP4_CUTLASS_GROUPED_GEMM.
- NVTE_GROUPED_LINEAR_USE_FUSED_GROUPED_GEMM=1 routes GroupedLinear onto that path.
Default behavior is unchanged (the backend is opt-in). #3134 is temporarily retained for comparison purposes.
Performance
SM100 (Blackwell), NVFP4 default recipe. cuBLAS single-launch vs CUTLASS single-launch, both through general_grouped_gemm_for_grouped_tensor. Times are in ms/iter (lower is better), iters=30. Reproduce with:
python benchmarks/linear/benchmark_grouped_linear.py --compare-backends          # eager, per-GEMM
python benchmarks/linear/benchmark_grouped_linear.py --compare-backends --graph  # CUDA-graphed whole step
Eager, isolated per-GEMM (cuBLAS / CUTLASS / speedup)
CUDA-graphed whole-step fwd+bwd (cuBLAS / CUTLASS / speedup)
Type of change
Changes
Common (CUDA/C++)
- Route the CUTLASS kernel through general_grouped_gemm_for_grouped_tensor / execute_grouped_gemm, consuming the device-side GroupedGemmSetupWorkspace arrays (graph-safe), selected by NVTE_NVFP4_CUTLASS_GROUPED_GEMM; fall back to cuBLAS when unsupported.
- Replace the host-side launcher run_grouped_per_tensor_gemm (and its persistent host buffer) with run_grouped_per_tensor_gemm_grouped_tensor.
- Remove the old multi-stream CUTLASS dispatch branch and the two bench-only entry points.
PyTorch
- tests/pytorch/test_grouped_linear.py: grouped-tensor backend parity (test_nvfp4_grouped_tensor_cutlass_matches_cublas) and CUDA-graph safety (test_nvfp4_grouped_tensor_cutlass_cuda_graph_safe); remove the dead multi-stream CUTLASS tests.
- benchmarks/linear/benchmark_grouped_linear.py: --compare-backends (eager per-GEMM) and --compare-backends --graph (CUDA-graphed whole-step).
Checklist:
- tests/pytorch/test_grouped_linear.py passes on SM100