From bbcae3e721cd52f2fbe34cb74ebd6af94cd4ae3b Mon Sep 17 00:00:00 2001 From: Cael Ling Date: Tue, 16 Jun 2026 22:39:33 -0700 Subject: [PATCH 1/2] Single-launch CUTLASS grouped GEMM for per-tensor NVFP4 Replace the per-expert multi-stream cuBLASLt loop in the per-tensor NVFP4 grouped path with one CUTLASS ptr-array grouped launch on SM100 (Blackwell). Common: - Add nvfp4_cutlass_grouped_gemm.{cuh,cu}: a single-launch per-tensor NVFP4 grouped kernel. Covers BF16 output (fprop/dgrad, overwrite), FP32 output (wgrad, with optional in-place accumulate for Megatron wgrad fusion), and optional fused per-group bias (fprop). Per-tensor scaling collapses the second-level scale to one fp32 alpha per group, applied via the epilogue's per-group alpha_ptr_array. Arch-gated on CUTLASS_ARCH_MMA_SM100_SUPPORTED. - Wire it into nvte_multi_tensor_gemm behind the opt-in env NVTE_NVFP4_CUTLASS_GROUPED_GEMM. M/N/K must be multiples of 128; empty (0-token) experts schedule 0 tiles and no longer force a multi-stream fallback. Anything unsupported falls through to the existing cuBLAS path, so default behavior is unchanged. - Add bench-only entry points nvte_nvfp4_grouped_per_tensor_compute_alpha / nvte_nvfp4_grouped_per_tensor_gemm so a benchmark can precompute alpha outside the timed region and time only the grouped GEMM launch. PyTorch: - Expose the two bench-only bindings (tex.nvfp4_grouped_per_tensor_compute_alpha and tex.nvfp4_grouped_per_tensor_gemm). Not used by the production dispatch. - Extend test_grouped_linear.py with NVFP4 cutlass-vs-multistream parity tests: GEMM-level (uniform + uneven 128-aligned splits), empty groups, and end-to-end GroupedLinear fwd+bwd (bias, fuse_wgrad_accumulation). - Add a GEMM-level cutlass-vs-multistream comparison to benchmark_grouped_linear.py (--compare-nvfp4-grouped-gemm): a DISPATCH row (both backends via the shared dispatch) and a fair PURE row (operands pre-swizzled, alpha precomputed; times only the GEMM). 
Signed-off-by: Cael Ling --- benchmarks/linear/benchmark_grouped_linear.py | 337 ++++++++++++- tests/pytorch/test_grouped_linear.py | 303 ++++++++++++ transformer_engine/common/CMakeLists.txt | 2 + .../common/gemm/cublaslt_gemm.cu | 232 +++++++++ .../common/gemm/nvfp4_cutlass_grouped_gemm.cu | 445 ++++++++++++++++++ .../gemm/nvfp4_cutlass_grouped_gemm.cuh | 61 +++ .../common/include/transformer_engine/gemm.h | 48 ++ transformer_engine/pytorch/csrc/extensions.h | 11 + .../pytorch/csrc/extensions/gemm.cpp | 74 +++ .../pytorch/csrc/extensions/pybind.cpp | 9 + 10 files changed, 1521 insertions(+), 1 deletion(-) create mode 100644 transformer_engine/common/gemm/nvfp4_cutlass_grouped_gemm.cu create mode 100644 transformer_engine/common/gemm/nvfp4_cutlass_grouped_gemm.cuh diff --git a/benchmarks/linear/benchmark_grouped_linear.py b/benchmarks/linear/benchmark_grouped_linear.py index cf88faac4f..8c02173eb4 100644 --- a/benchmarks/linear/benchmark_grouped_linear.py +++ b/benchmarks/linear/benchmark_grouped_linear.py @@ -4,6 +4,7 @@ import argparse import os +import statistics import torch import torch.utils.benchmark as benchmark import pandas as pd @@ -15,7 +16,11 @@ NVFP4BlockScaling, ) from transformer_engine.pytorch.quantization import autocast, FP8GlobalStateManager -from contextlib import nullcontext +from transformer_engine.pytorch.cpp_extensions import general_grouped_gemm +from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer +import transformer_engine_torch as tex +from contextlib import contextmanager, nullcontext +from typing import List, Optional, Tuple """ # Profile BF16 recipe with Nsight Systems @@ -236,6 +241,291 @@ def run_benchmark_linear( return df +# ============================================================================= +# NVFP4 grouped GEMM backend comparison (GEMM-level): single-launch CUTLASS +# per-tensor grouped kernel vs the production multi-stream cuBLASLt per-expert +# loop. 
Both sit behind the SAME dispatch (nvte_multi_tensor_gemm); only the env +# NVTE_NVFP4_CUTLASS_GROUPED_GEMM selects between them, read fresh per call. +# Enabled with --compare-nvfp4-grouped-gemm. Operands are quantized ONCE +# (untimed); only the grouped GEMM is timed -- the fair backend comparison. +# Requires a Blackwell (SM100) build with the kernel/binding compiled in. +# ============================================================================= +_NVFP4_GG_ENV = "NVTE_NVFP4_CUTLASS_GROUPED_GEMM" + + +def _has_sm100() -> bool: + if not torch.cuda.is_available(): + return False + return torch.cuda.get_device_capability()[0] == 10 + + +@contextmanager +def _nvfp4_gg_backend(cutlass: bool): + """Toggle the cutlass/multi-stream env for the duration of a timing block.""" + prev = os.environ.get(_NVFP4_GG_ENV) + os.environ[_NVFP4_GG_ENV] = "1" if cutlass else "0" + try: + yield + finally: + if prev is None: + os.environ.pop(_NVFP4_GG_ENV, None) + else: + os.environ[_NVFP4_GG_ENV] = prev + + +def _nvfp4_gg_token_counts(num_experts: int, mean_m: int, imbalanced: bool, seed: int) -> List[int]: + """Per-expert token counts, each a multiple of 128 (the path's alignment + contract). 
Balanced => all equal; imbalanced => uniform in [0.25x, 1.75x] mean.""" + if not imbalanced: + return [mean_m] * num_experts + g = torch.Generator().manual_seed(seed) + lo, hi = 0.25, 1.75 + blocks_mean = max(mean_m // 128, 1) + out = [] + for _ in range(num_experts): + frac = lo + (hi - lo) * torch.rand(1, generator=g).item() + blocks = max(int(round(blocks_mean * frac)), 1) + out.append(blocks * 128) + return out + + +def _nvfp4_gg_label(Ms: List[int]) -> str: + return f"{len(Ms)}e x {'imbal' if len(set(Ms)) > 1 else Ms[0]}" + + +def _nvfp4_gg_quantizer() -> NVFP4Quantizer: + """Per-tensor NVFP4 (1D, no RHT/SR/2D) so the cutlass path is eligible.""" + return NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + rowwise=True, + columnwise=True, + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=False, + stochastic_rounding=False, + with_random_sign_mask=False, + ) + + +def _nvfp4_gg_quantize(hp: torch.Tensor): + q = _nvfp4_gg_quantizer() + dst = q.make_empty(hp.shape, dtype=torch.bfloat16, device=hp.device) + if hp.numel() != 0: + tex.quantize(hp, q, dst, None) + return dst + + +def _nvfp4_gg_time_us(fn, warmup: int, iters: int) -> float: + """Median wall time of fn() in microseconds via CUDA events.""" + for _ in range(warmup): + fn() + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + times: List[float] = [] + for _ in range(iters): + start.record() + fn() + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) # ms + return statistics.median(times) * 1e3 + + +def _nvfp4_gg_build(layout: str, Ms: List[int], N: int, K: int): + """Returns (A, B, out, m_splits, grad, single_output, flops) for one layout. 
+ Operand order matches GroupedLinear: A=weight side, B=activation/grad side.""" + dev = torch.device("cuda") + groups = len(Ms) + m = sum(Ms) + flops = 2.0 * m * N * K # one logical GEMM, summed over all groups + + def qt(*sz): + return _nvfp4_gg_quantize(torch.randn(*sz, dtype=torch.bfloat16, device=dev) * 0.5) + + if layout == "TN": # fprop: A=W(N,K), B=X(M,K) -> out(M,N) + A = [qt(N, K) for _ in range(groups)] + B = [qt(mm, K) for mm in Ms] + out = [torch.empty(m, N, dtype=torch.bfloat16, device=dev)] + return A, B, out, Ms, False, True, flops + if layout == "NN": # dgrad: A=W(N,K), B=dY(M,N) -> dgrad(M,K) + A = [qt(N, K) for _ in range(groups)] + B = [qt(mm, N) for mm in Ms] + out = [torch.empty(m, K, dtype=torch.bfloat16, device=dev)] + return A, B, out, Ms, True, True, flops + # NT wgrad: A=X(M,K), B=dY(M,N) -> wgrad(N,K), fp32 out + A = [qt(mm, K) for mm in Ms] + B = [qt(mm, N) for mm in Ms] + out = [torch.empty(N, K, dtype=torch.float32, device=dev) for _ in range(groups)] + return A, B, out, Ms, True, False, flops + + +def _nvfp4_gg_d_groups(layout: str, Ms: List[int], out: List[torch.Tensor]) -> List[torch.Tensor]: + """Per-group output tensors for the direct binding. TN/NN slice a single + packed output by tokens along dim 0; NT (wgrad) is already a per-group list.""" + if layout == "NT": + return out + big = out[0] + groups, s = [], 0 + for mm in Ms: + groups.append(big[s : s + mm]) + s += mm + return groups + + +def _nvfp4_gg_bench_pure( + layout, Ms, N, K, warmup, iters +) -> Tuple[Optional[float], Optional[float]]: + """Fair pure-GEMM comparison. Builds FRESH operands and pre-swizzles their + scales IN PLACE (untimed), then times BOTH backends on the SAME pre-swizzled + operands so the per-call scale swizzle is excluded from both timers equally: + * multi-stream-pure : general_grouped_gemm (env=0). 
The dispatch skips the + (already done) swizzle; the per-expert alpha is still recomputed inside + cuBLASLt, but on 4 streams it overlaps the GEMMs (~hidden). + * cutlass-pure : the direct binding with alpha precomputed (untimed) -- + the single grouped launch in isolation. + Returns (multistream_pure_us, cutlass_pure_us). Either may be None. + """ + if not hasattr(tex, "nvfp4_grouped_per_tensor_gemm"): + return None, None # binding not compiled in (stale build) + try: + A, B, out, m_splits, grad, single_output, _flops = _nvfp4_gg_build(layout, Ms, N, K) + transa = layout[0] == "T" + transb = layout[1] == "T" + d_groups = _nvfp4_gg_d_groups(layout, Ms, out) + # Pre-swizzle exactly as te_general_grouped_gemm does before the GEMM: + # A -> (rowwise=transa, columnwise=!transa); B -> (rowwise=!transb, columnwise=transb). + tex.multi_tensor_swizzle_scales_for_gemm_(A, transa, not transa) + tex.multi_tensor_swizzle_scales_for_gemm_(B, not transb, transb) + alpha = tex.nvfp4_grouped_per_tensor_compute_alpha(A, transa, B, transb) + except Exception: # noqa: BLE001 -- bench: missing kernel -> drop the columns + return None, None + + def ms_pure() -> Optional[float]: + try: + with _nvfp4_gg_backend(False): # multi-stream, on pre-swizzled operands + return _nvfp4_gg_time_us( + lambda: general_grouped_gemm( + A, B, out, [None] * len(Ms), out[0].dtype, + layout=layout, m_splits=m_splits, single_output=single_output, grad=grad, + ), + warmup, iters, + ) + except Exception: # noqa: BLE001 + return None + + def cu_pure() -> Optional[float]: + try: + return _nvfp4_gg_time_us( + lambda: tex.nvfp4_grouped_per_tensor_gemm( + A, transa, B, transb, d_groups, [], alpha, False + ), + warmup, iters, + ) + except Exception: # noqa: BLE001 + return None + + return ms_pure(), cu_pure() + + +def _nvfp4_gg_bench( + layout, Ms, N, K, warmup, iters, pure: bool = True +) -> Tuple[Optional[float], Optional[float], Optional[float], Optional[float], float, str]: + """Returns (multistream_us, 
cutlass_us, multistream_pure_us, cutlass_pure_us, flops, note).""" + try: + A, B, out, m_splits, grad, single_output, flops = _nvfp4_gg_build(layout, Ms, N, K) + except Exception as exc: # noqa: BLE001 + return None, None, None, None, 0.0, f"ERROR(build): {type(exc).__name__}: {str(exc)[:100]}" + + def run(): + general_grouped_gemm( + A, B, out, [None] * len(Ms), out[0].dtype, + layout=layout, m_splits=m_splits, single_output=single_output, grad=grad, + ) + + def timed(cutlass: bool) -> Optional[float]: + try: + with _nvfp4_gg_backend(cutlass): + return _nvfp4_gg_time_us(run, warmup, iters) + except Exception: # noqa: BLE001 + return None + + ms_us, cu_us = timed(False), timed(True) + # Pure path uses FRESH operands: its in-place pre-swizzle must not strip the + # per-call swizzle that the dispatch columns above legitimately pay. + ms_pure_us, cu_pure_us = ( + _nvfp4_gg_bench_pure(layout, Ms, N, K, warmup, iters) if pure else (None, None) + ) + return ms_us, cu_us, ms_pure_us, cu_pure_us, flops, "" + + +def run_nvfp4_grouped_gemm_comparison(layouts, configs, warmup, iters, want_pure) -> None: + """Driver: print a dispatch row (cutlass vs multi-stream, both via dispatch) and, + if want_pure, a fair PURE row (both pre-swizzled) for each layout x config.""" + if not _has_sm100(): + print("SKIP: NVFP4 grouped GEMM comparison requires SM100 (Blackwell)") + return + if not nvfp4_available: + print(f"SKIP: NVFP4 not available ({reason_for_no_nvfp4})") + return + + layout_label = {"TN": "TN fprop", "NN": "NN dgrad", "NT": "NT wgrad"} + pure_hdr = "" if not want_pure else ( + " + PURE row (fair kernel-vs-kernel): both pre-swizzled, swizzle excluded from both.\n" + ) + print( + f"\nNVFP4 grouped GEMM: CUTLASS vs multi-stream cuBLASLt " + f"[warmup={warmup} iters={iters}]\n" + f" DISPATCH row (real prod): multi-stream = env=0 (4-stream cuBLASLt), cutlass = env=1.\n" + f"{pure_hdr}" + f" speedup = multi-stream / cutlass (>1 => cutlass faster).\n" + ) + + def _ms(us: 
Optional[float]) -> str: + return f"{us / 1e3:.3f}ms" if (us is not None) else "-" + + def _spd(num: Optional[float], den: Optional[float]) -> str: + return f"{num / den:.2f}x" if (num and den and den > 0) else "-" + + header = ( + f" {'shape':<12} {'N':>5} {'K':>5} {'tok':>6} {'row':<9} " + f"{'cutlass':>10} {'multi-stream':>12} {'speedup':>8}" + ) + + def _emit(shape, n, k, tok, row, cu_us, ms_us): + print( + f" {shape:<12} {n:>5} {k:>5} {tok:>6} {row:<9} " + f"{_ms(cu_us):>10} {_ms(ms_us):>12} {_spd(ms_us, cu_us):>8}" + ) + + for layout in layouts: + print(f" [{layout_label[layout]}]") + print(header) + print(" " + "-" * (len(header) - 2)) + for Ms, N, K in configs: + ms_us, cu_us, ms_pure_us, cu_pure_us, _flops, note = _nvfp4_gg_bench( + layout, Ms, N, K, warmup, iters, pure=want_pure + ) + _emit(_nvfp4_gg_label(Ms), N, K, sum(Ms), "dispatch", cu_us, ms_us) + if want_pure: + _emit("", "", "", "", "pure", cu_pure_us, ms_pure_us) + if note: + print(f" {note}") + print() + + print( + " DISPATCH row = real prod path for both backends; both pay per-call swizzle +\n" + " per-expert alpha in the timer (cutlass runs the alpha kernels serially, multi-\n" + " stream overlaps them across 4 streams -- so it may under-sell cutlass until the\n" + " alpha launches are batched in a follow-up PR). PURE row = fair kernel-vs-kernel:\n" + " operands pre-swizzled (untimed) for BOTH; cutlass-pure also precomputes alpha\n" + " (untimed) and times only tex.nvfp4_grouped_per_tensor_gemm. If the PURE row is\n" + " blank, the kernel/binding is not compiled in -- rebuild on Blackwell. 
If the\n" + " DISPATCH speedup reads ~1.00x everywhere, env=1 is silently falling back to\n" + " multi-stream (also a stale build).\n" + ) + + if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -279,6 +569,26 @@ def run_benchmark_linear( default=False, help="Run forward pass only, default is both forward and backward passes", ) + # NVFP4 grouped GEMM backend comparison (GEMM-level, cutlass vs multi-stream). + parser.add_argument( + "--compare-nvfp4-grouped-gemm", + action="store_true", + help="GEMM-level NVFP4 cutlass vs multi-stream cuBLASLt comparison (then exit)", + ) + parser.add_argument( + "--layouts", + nargs="+", + default=["TN", "NN", "NT"], + choices=["TN", "NN", "NT"], + help="layouts for --compare-nvfp4-grouped-gemm (TN fprop, NN dgrad, NT wgrad)", + ) + parser.add_argument( + "--no-pure", + action="store_true", + help="drop the fair PURE row in --compare-nvfp4-grouped-gemm", + ) + parser.add_argument("--gemm-warmup", type=int, default=10) + parser.add_argument("--gemm-iters", type=int, default=100) args = parser.parse_args() jagged_input_splits = None @@ -288,6 +598,31 @@ def run_benchmark_linear( print(f"Jagged input splits sum: {sum(jagged_input_splits)}") print(f"Jagged input splits num_gemms: {len(jagged_input_splits)}") + # GEMM-level NVFP4 cutlass-vs-multi-stream comparison (separate from the module + # benchmark below). Honors --jagged-input / --hidden-dim / --output-dim as a + # single custom config; otherwise uses a built-in MoE-shaped config sweep. + if args.compare_nvfp4_grouped_gemm: + if jagged_input_splits is not None: + # The per-tensor cutlass path requires tokens % 128 == 0; align up so the + # path is eligible (a non-aligned split would just fall back to cuBLAS). 
+ Ms = [max((s + 127) // 128, 1) * 128 for s in jagged_input_splits] + gg_configs = [(Ms, args.output_dim, args.hidden_dim)] + else: + gg_configs = [ + (_nvfp4_gg_token_counts(8, 128, False, 0), 2048, 2048), # small (launch-bound) + (_nvfp4_gg_token_counts(8, 256, False, 0), 2048, 2048), + (_nvfp4_gg_token_counts(8, 512, False, 0), 2048, 2048), + (_nvfp4_gg_token_counts(8, 256, True, 1), 2048, 2048), # imbalanced + (_nvfp4_gg_token_counts(16, 256, False, 0), 2048, 2048), + (_nvfp4_gg_token_counts(16, 256, True, 2), 4096, 2048), # imbalanced, wider N + (_nvfp4_gg_token_counts(32, 128, False, 0), 2048, 2048), # many small experts + (_nvfp4_gg_token_counts(32, 256, True, 3), 2048, 2048), # many imbalanced + ] + run_nvfp4_grouped_gemm_comparison( + args.layouts, gg_configs, args.gemm_warmup, args.gemm_iters, not args.no_pure + ) + raise SystemExit(0) + use_bias = False # Set the MKN values to benchmark # Deepseek V3 EP64, SEQ_LEN=8192, topK8 diff --git a/tests/pytorch/test_grouped_linear.py b/tests/pytorch/test_grouped_linear.py index c1a9e0a407..7f05772a39 100644 --- a/tests/pytorch/test_grouped_linear.py +++ b/tests/pytorch/test_grouped_linear.py @@ -920,6 +920,309 @@ def test_grouped_gemm_cutlass_empty_groups(layout, monkeypatch): torch.testing.assert_close(tensor, torch.zeros_like(tensor), rtol=0, atol=0) +# ============================================================================= +# NVFP4 per-tensor single-launch CUTLASS grouped GEMM (Blackwell / SM100). +# +# Opt-in via NVTE_NVFP4_CUTLASS_GROUPED_GEMM; a drop-in replacement for the +# production multi-stream cuBLASLt per-expert loop. These tests mirror the BF16 +# Hopper cutlass coverage above -- GEMM-level (test_grouped_gemm), empty groups +# (test_grouped_gemm_cutlass_empty_groups) and module-level +# (test_grouped_linear_accuracy_cutlass) -- but for the NVFP4 per-tensor path, +# which is SM100-only and additionally fuses bias (fprop) / accumulate (wgrad). 
+# ============================================================================= +_NVFP4_CUTLASS_ENV = "NVTE_NVFP4_CUTLASS_GROUPED_GEMM" +nvfp4_cutlass_grouped_available = ( + nvfp4_available and torch.cuda.get_device_capability()[0] == 10 +) + + +def _nvfp4_pertensor_quantize(hp: torch.Tensor): + """High-precision -> per-tensor NVFP4 (rowwise+columnwise, no RHT/SR/2D) so + operand canonicalization can pick either orientation and a plain dequantize() + is a faithful neutral reference. A 0-row operand (empty expert) skips the + cast but still returns a valid NVFP4 tensor of the right shape.""" + q = NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + rowwise=True, + columnwise=True, + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=False, + stochastic_rounding=False, + with_random_sign_mask=False, + ) + dst = q.make_empty(hp.shape, dtype=torch.bfloat16, device=hp.device) + if hp.numel() != 0: + tex.quantize(hp, q, dst, None) + return dst + + +def _diff(ref: torch.Tensor, test: torch.Tensor): + """(max_abs, global_inf_norm_rel, ref_inf). The global rel (max_abs / ||ref||) + is robust to the near-zero output elements that make per-element rel explode.""" + ref = ref.float() + test = test.float() + max_abs = (ref - test).abs().max().item() + ref_inf = ref.abs().max().item() + return max_abs, max_abs / max(ref_inf, 1e-6), ref_inf + + +def _nvfp4_dequant_reference(A, B, *, layout: str, bias=None, init=None): + """Independent fp32 reference: dequantize the NVFP4 operands and matmul in + fp32, WITHOUT either GEMM backend. Catches a wrong orientation / scale even + when cutlass and multi-stream agree bit-for-bit. The rowwise dequant differs + from the orientation the kernel consumes (esp. 
NN/NT), so it is a few % off + *both* backends -- the test is that cutlass is no less accurate than the + production path, not that it beats this reference.""" + refs = [] + for g, (a, b) in enumerate(zip(A, B)): + ad = a.dequantize(dtype=torch.float32) + bd = b.dequantize(dtype=torch.float32) + if layout == "TN": # out(m,N) = X(m,K) @ W(N,K)^T ; A=W, B=X + r = bd @ ad.t() + if bias is not None: + r = r + bias[g].float().unsqueeze(0) + elif layout == "NN": # out(m,K) = dY(m,N) @ W(N,K) ; A=W, B=dY + r = bd @ ad + else: # NT: out(N,K) = dY(m,N)^T @ X(m,K) ; A=X, B=dY + r = bd.t() @ ad + if init is not None: + r = r + init[g].float() + refs.append(r) + if layout in ("TN", "NN"): + return [torch.cat(refs, dim=0)] # single_output: concat over groups (M) + return refs + + +def _run_nvfp4_grouped(A, B, out, *, layout, grad, accumulate, m_splits, single_output, + bias, cutlass, monkeypatch): + monkeypatch.setenv(_NVFP4_CUTLASS_ENV, "1" if cutlass else "0") + general_grouped_gemm( + A, + B, + out, + [None] * len(A), + out[0].dtype, + layout=layout, + m_splits=m_splits, + single_output=single_output, + grad=grad, + accumulate=accumulate, + bias=bias, + use_bias=bias is not None, + ) + + +def _assert_nvfp4_grouped_parity(out_ms, out_cu, hp): + """out_ms = multi-stream (reference), out_cu = cutlass, hp = neutral dequant + reference. Asserts (1) cutlass tracks multi-stream and (2) cutlass is no less + accurate than multi-stream against the neutral reference.""" + for ms, cu, h in zip(out_ms, out_cu, hp): + if cu.numel() == 0: + continue + abs_d, rel_d, ms_inf = _diff(ms, cu) + assert ms_inf > 1e-6, "reference output is ~0 (operand/quant bug, not a real check)" + # Backend consistency: overwrite paths are bit-identical; bias / accumulate + # add a ~1 ULP fp32/bf16 rounding diff. Allow either bound. 
+ assert abs_d <= 5e-2 or rel_d <= 2e-2, ( + f"cutlass vs multi-stream diverged: max_abs={abs_d:.4g}, rel={rel_d:.4g}" + ) + # Correctness: cutlass no worse than production vs the neutral reference. + _, ms_hp, _ = _diff(h, ms) + _, cu_hp, _ = _diff(h, cu) + assert cu_hp <= max(0.05, ms_hp * 1.3), ( + f"cutlass less accurate than multi-stream vs dequant ref: " + f"cu={cu_hp:.4g}, ms={ms_hp:.4g}" + ) + + +def _build_nvfp4_grouped_operands(layout, m_splits, k, n, *, accumulate, use_bias, odt, dev): + """Quantize per-tensor NVFP4 operands for one capability case. Returns + (A, B, out_ms, out_cu, grad, single_output, bias, init). out_ms/out_cu are + two independent buffers (same init) so multi-stream and cutlass never alias.""" + z = len(m_splits) + m = sum(m_splits) + + def qt(*sz): + return _nvfp4_pertensor_quantize(torch.randn(*sz, dtype=torch.bfloat16, device=dev) * 0.5) + + bias = init = None + if layout == "TN": # fprop: A=W(n,k), B=X(m,k) -> out(m,n) + A = [qt(n, k) for _ in range(z)] + B = [qt(mm, k) for mm in m_splits] + out_ms = [torch.zeros(m, n, dtype=odt, device=dev)] + out_cu = [torch.zeros(m, n, dtype=odt, device=dev)] + grad, single_output = False, True + if use_bias: + bias = [torch.randn(n, dtype=odt, device=dev) for _ in range(z)] + elif layout == "NN": # dgrad: A=W(n,k), B=dY(m,n) -> dgrad(m,k) + A = [qt(n, k) for _ in range(z)] + B = [qt(mm, n) for mm in m_splits] + out_ms = [torch.zeros(m, k, dtype=odt, device=dev)] + out_cu = [torch.zeros(m, k, dtype=odt, device=dev)] + grad, single_output = True, True + else: # NT wgrad: A=X(m,k), B=dY(m,n) -> wgrad(n,k) per group + A = [qt(mm, k) for mm in m_splits] + B = [qt(mm, n) for mm in m_splits] + if accumulate: # in-place main_grad accumulate: both backends share init + init = [torch.randn(n, k, dtype=odt, device=dev) for _ in range(z)] + out_ms = [t.clone() for t in init] + out_cu = [t.clone() for t in init] + else: + out_ms = [torch.zeros(n, k, dtype=odt, device=dev) for _ in range(z)] + out_cu = 
[torch.zeros(n, k, dtype=odt, device=dev) for _ in range(z)] + grad, single_output = True, False + return A, B, out_ms, out_cu, grad, single_output, bias, init + + +def _run_nvfp4_gemm_case(layout, fp32_out, accumulate, use_bias, m_splits, k, n, monkeypatch): + """Build operands once, run both backends, assert parity + correctness.""" + torch.manual_seed(0) + dev = "cuda" + odt = torch.float32 if fp32_out else torch.bfloat16 + A, B, out_ms, out_cu, grad, single_output, bias, init = _build_nvfp4_grouped_operands( + layout, m_splits, k, n, accumulate=accumulate, use_bias=use_bias, odt=odt, dev=dev + ) + _run_nvfp4_grouped(A, B, out_ms, layout=layout, grad=grad, accumulate=accumulate, + m_splits=m_splits, single_output=single_output, bias=bias, + cutlass=False, monkeypatch=monkeypatch) + _run_nvfp4_grouped(A, B, out_cu, layout=layout, grad=grad, accumulate=accumulate, + m_splits=m_splits, single_output=single_output, bias=bias, + cutlass=True, monkeypatch=monkeypatch) + hp = _nvfp4_dequant_reference(A, B, layout=layout, bias=bias, init=init) + _assert_nvfp4_grouped_parity(out_ms, out_cu, hp) + + +# (id, layout, fp32_out, accumulate, use_bias) -- the capability matrix of the +# per-tensor path: TN fprop (+bias), NN dgrad, NT wgrad (fresh + accumulate). +_NVFP4_GEMM_CASES = [ + ("fprop", "TN", False, False, False), + ("fprop_bias", "TN", False, False, True), + ("dgrad", "NN", False, False, False), + ("wgrad", "NT", True, False, False), + ("wgrad_accum", "NT", True, True, False), +] + + +@pytest.mark.skipif( + not nvfp4_cutlass_grouped_available, + reason="NVFP4 CUTLASS grouped GEMM requires Blackwell (SM100) + NVFP4", +) +@pytest.mark.parametrize( + "shape", + [(4, 512, 256, 256), (8, 1024, 512, 256), (2, 256, 256, 512)], + ids=lambda s: f"{s[0]}x{s[1]}_{s[2]}x{s[3]}", +) +@pytest.mark.parametrize("case", _NVFP4_GEMM_CASES, ids=lambda c: c[0]) +def test_nvfp4_cutlass_grouped_gemm(shape, case, monkeypatch): + """GEMM-level parity. 
Operands are quantized ONCE (per-tensor NVFP4, all dims + %128 so the path is eligible) and fed to both backends, so any divergence is + a GEMM-backend bug, not a quant artifact. Uniform (equal) splits.""" + _, layout, fp32_out, accumulate, use_bias = case + z, m, k, n = shape + m_per = m // z + assert m_per * z == m and m_per % 128 == 0 and k % 128 == 0 and n % 128 == 0 + _run_nvfp4_gemm_case(layout, fp32_out, accumulate, use_bias, [m_per] * z, k, n, monkeypatch) + + +@pytest.mark.skipif( + not nvfp4_cutlass_grouped_available, + reason="NVFP4 CUTLASS grouped GEMM requires Blackwell (SM100) + NVFP4", +) +@pytest.mark.parametrize("case", _NVFP4_GEMM_CASES, ids=lambda c: c[0]) +def test_nvfp4_cutlass_grouped_gemm_uneven_splits(case, monkeypatch): + """Same capability matrix but with unequal (still 128-aligned) per-group + sizes, mirroring the uneven token distribution real MoE routing produces. + Stresses the per-group offset / pointer setup that equal splits cannot.""" + _, layout, fp32_out, accumulate, use_bias = case + m_splits = [128, 256, 384, 512] # all %128, deliberately unequal + _run_nvfp4_gemm_case(layout, fp32_out, accumulate, use_bias, m_splits, 256, 256, monkeypatch) + + +@pytest.mark.skipif( + not nvfp4_cutlass_grouped_available, + reason="NVFP4 CUTLASS grouped GEMM requires Blackwell (SM100) + NVFP4", +) +@pytest.mark.parametrize("layout", ["TN", "NN"]) +def test_nvfp4_cutlass_grouped_gemm_empty_groups(layout, monkeypatch): + """Empty experts (0-token groups). Tokens map to cuBLAS n for TN/NN, so an + empty group is n=0; CUTLASS schedules 0 tiles for it and the non-empty groups + must still compute (a single empty expert must NOT veto the batch). 
NT is + omitted: an empty NT group is K=0, which the dispatcher keeps on multi-stream + by design.""" + m_splits = [0, 128, 0, 128] # mix of empty / non-empty experts + _run_nvfp4_gemm_case(layout, False, False, False, m_splits, 256, 256, monkeypatch) + + +@pytest.mark.skipif( + not nvfp4_cutlass_grouped_available, + reason="NVFP4 CUTLASS grouped GEMM requires Blackwell (SM100) + NVFP4", +) +@pytest.mark.parametrize("bias", all_boolean) +@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean) +@pytest.mark.parametrize("groups", [3, 6]) +def test_nvfp4_cutlass_grouped_linear(groups, bias, fuse_wgrad_accumulation, monkeypatch): + """End-to-end GroupedLinear fwd+bwd: cutlass (env=1) vs multi-stream (env=0) + through the real module wiring under a per-tensor NVFP4 recipe. Identical + weights / x / dy (RNG reset before each run) feed both, so any divergence in + output, dgrad, wgrad or dbias is a backend bug. Exercises fprop/dgrad/wgrad, + fused bias and the main_grad (fuse_wgrad_accumulation) accumulate path, across + odd/even group counts with unequal (128-aligned) per-expert token counts.""" + nvfp4_recipe = recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + disable_2d_quantization=True, + ) + K, N = 2048, 2048 # all dims %128 -> path eligible + m_splits = [128 * ((i % 3) + 1) for i in range(groups)] # unequal, 128-aligned + total_m = sum(m_splits) + + torch.manual_seed(0) + model = GroupedLinear( + groups, K, N, bias=bias, params_dtype=torch.bfloat16, device="cuda", + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + ).eval() + x = torch.randn(total_m, K, dtype=torch.bfloat16, device="cuda", requires_grad=True) + dy = torch.randn(total_m, N, dtype=torch.bfloat16, device="cuda") + init_mg = ( + [torch.randn(N, K, dtype=torch.float32, device="cuda") for _ in range(groups)] + if fuse_wgrad_accumulation + else None + ) + + def run(cutlass: bool): + monkeypatch.setenv(_NVFP4_CUTLASS_ENV, "1" if cutlass else "0") + 
reset_rng_states() # identical quantization randoms in both runs + model.zero_grad(set_to_none=True) + if x.grad is not None: + x.grad = None + if fuse_wgrad_accumulation: + for i in range(groups): + getattr(model, f"weight{i}").main_grad = init_mg[i].clone() + with autocast(enabled=True, recipe=nvfp4_recipe): + out = model(x, m_splits) + out.backward(dy) + snap = {"out": out.detach().float().clone(), + "dgrad": x.grad.detach().float().clone()} + for i in range(groups): + w = getattr(model, f"weight{i}") + g = w.main_grad if fuse_wgrad_accumulation else w.grad + snap[f"wgrad{i}"] = g.detach().float().clone() + if bias: + b = getattr(model, f"bias{i}") + if b.grad is not None: + snap[f"dbias{i}"] = b.grad.detach().float().clone() + return snap + + ref = run(cutlass=False) + test = run(cutlass=True) + for key in ref: + abs_d, rel_d, _ = _diff(ref[key], test[key]) + assert abs_d <= 5e-2 or rel_d <= 2e-2, ( + f"{key}: cutlass vs multi-stream diverged (max_abs={abs_d:.4g}, rel={rel_d:.4g})" + ) + + def _pack_grouped_tensor(grouped_tensor: GroupedTensor, tensors: List[torch.Tensor]) -> None: data = grouped_tensor.rowwise_data if data is None: diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 8f96432ed8..4962540961 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -264,6 +264,7 @@ list(APPEND transformer_engine_cuda_arch_specific_sources cast/cast_grouped.cu cast/cast_grouped_dbias.cu gemm/cutlass_grouped_gemm.cu + gemm/nvfp4_cutlass_grouped_gemm.cu hadamard_transform/group_hadamard_transform.cu hadamard_transform/graph_safe_group_hadamard_transform.cu hadamard_transform/hadamard_transform.cu @@ -345,6 +346,7 @@ set_property( # CUTLASS kernels could cause hang in debug build set(CUTLASS_KERNEL_SOURCES gemm/cutlass_grouped_gemm.cu + gemm/nvfp4_cutlass_grouped_gemm.cu hadamard_transform/group_hadamard_transform_cast_fusion.cu 
hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu hadamard_transform/hadamard_transform_cast_fusion.cu) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index a0529c80c0..715cff008f 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -24,6 +24,7 @@ #include "../util/multi_stream.h" #include "./config.h" #include "./cutlass_grouped_gemm.cuh" +#include "./nvfp4_cutlass_grouped_gemm.cuh" namespace { @@ -1071,6 +1072,155 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor workspace, accumulate, use_split_accumulator, math_sm_count, stream); }; + // ---- Blackwell single-launch CUTLASS grouped GEMM for per-tensor NVFP4 ---- + // Opt-in (NVTE_NVFP4_CUTLASS_GROUPED_GEMM) replacement for the multi-stream + // cuBLASLt loop. Covers per-tensor NVFP4 with M/N/K % 128: + // * BF16 output, overwrite -> fprop / dgrad + // * BF16 output, overwrite, fused per-group bias -> fprop with bias + // * FP32 output, overwrite -> wgrad (fresh) + // * FP32 output, accumulate (beta=1, in-place) -> Megatron wgrad fusion + // Empty experts (0 tokens) are fine: CUTLASS schedules 0 tiles for an M==0 + // group, so a single empty expert does NOT force the whole batch back to + // multi-stream (common in real MoE routing). + // Anything else (gelu/dbias, per-token, non-128 shapes, bf16 accumulate, + // fp32+bias, ...) falls through to the existing cuBLAS path, so production + // behavior is unchanged by default. + const bool is_blackwell = (transformer_engine::cuda::sm_arch(current_device) >= 100 && + transformer_engine::cuda::sm_arch(current_device) < 110); + if (is_blackwell && + transformer_engine::getenv("NVTE_NVFP4_CUTLASS_GROUPED_GEMM", false)) { + using namespace transformer_engine; + const cublasOperation_t transaOp = transa ? CUBLAS_OP_T : CUBLAS_OP_N; + const cublasOperation_t transbOp = transb ? 
CUBLAS_OP_T : CUBLAS_OP_N; + + auto has_aux = [&](const NVTETensor *p, int i) -> bool { + return p != nullptr && convertNVTETensor(p[i]) != nullptr && + convertNVTETensor(p[i])->has_data(); + }; + // Fused per-group bias is supported for the forward overwrite case only + // (BF16 out, no accumulate, grad == false -> a plain epilogue add, not the + // BGRADB dbias reduction). It must be uniform across groups. + const bool bias_present = has_aux(bias, 0); + auto eligible = [&]() -> bool { + for (int i = 0; i < num_gemms; ++i) { + // gelu fusion is never taken on this path (no NVFP4 grouped caller fuses + // it; the aux pre-gelu store EVT is intentionally not implemented). + if (has_aux(pre_gelu_out, i)) return false; + if (has_aux(bias, i) != bias_present) return false; // must be uniform + const auto *iA = convertNVTETensorCheck(A[i]); + const auto *iB = convertNVTETensorCheck(B[i]); + const auto *oD = convertNVTETensorCheck(D[i]); + if (!(is_nvfp_scaling(iA->scaling_mode) && is_nvfp_scaling(iB->scaling_mode))) return false; + if (iA->row_scaled_nvfp4 || iB->row_scaled_nvfp4) return false; // per-token -> skip + // BF16 out (fprop/dgrad, overwrite) or FP32 out (wgrad). accumulate + // (Megatron wgrad fusion) reads+writes D in-place and needs FP32 out. + const bool bf16_out = oD->data.dtype == DType::kBFloat16; + const bool fp32_out = oD->data.dtype == DType::kFloat32; + if (!bf16_out && !fp32_out) return false; + if (accumulate && !fp32_out) return false; + if (bias_present) { + // Forward bias add only: BF16 overwrite, and bias matches the output. + if (grad || accumulate || !bf16_out) return false; + const auto *bt = convertNVTETensorCheck(bias[i]); + if (bt->data.dtype != DType::kBFloat16) return false; + } + const auto [A0, A1] = iA->flat_2d_dims(); + const auto [B0, B1] = iB->flat_2d_dims(); + const int m = transa ? A0 : A1; // out_features (static weight dim, > 0) + const int n = transb ? B1 : B0; // tokens (0 == empty expert, allowed) + const int k = transa ? 
A1 : A0; // hidden (static weight dim, > 0) + // n (tokens) may be 0 for an empty expert. CUTLASS grouped GEMM schedules + // 0 tiles for a group with M==0, so a single empty group must NOT veto the + // whole batch (that would force a multi-stream fallback in real MoE loads + // where empty experts are common). m/k are the static weight dims. + if (m <= 0 || k <= 0 || n < 0) return false; + if ((m % 128) || (n % 128) || (k % 128)) return false; + } + return true; + }; + + if (eligible()) { + std::vector a_data(num_gemms), b_data(num_gemms), a_sf(num_gemms), + b_sf(num_gemms); + std::vector alpha_ptrs(num_gemms); + std::vector d_ptrs(num_gemms); + std::vector Ms(num_gemms), Ns(num_gemms), Ks(num_gemms); + // Per-group fused bias pointers (length N == cuBLAS m per group). Left + // empty when there is no bias, which selects the no-bias kernel. + std::vector bias_data(bias_present ? num_gemms : 0); + + // Per-group second-level scale (alpha) computed with the exact same kernel + // cuBLASLt uses, so results match the multi-stream path bit-for-bit modulo + // accumulation order. + float *alpha_buf = nullptr; + NVTE_CHECK_CUDA(cudaMallocAsync(&alpha_buf, sizeof(float) * num_gemms, stream)); + const bool a_rowwise_amax = transa; // transa == T + const bool b_rowwise_amax = !transb; // transb != T + + int num_nonempty = 0; + for (int i = 0; i < num_gemms; ++i) { + const auto *iA = convertNVTETensorCheck(A[i]); + const auto *iB = convertNVTETensorCheck(B[i]); + auto *oD = convertNVTETensorCheck(D[i]); + const auto [A0, A1] = iA->flat_2d_dims(); + const auto [B0, B1] = iB->flat_2d_dims(); + const int m = transa ? A0 : A1; + const int n = transb ? B1 : B0; + const int k = transa ? A1 : A0; + + // CUTLASS M (== cuBLAS n == tokens). Record the (possibly zero) problem + // size up front; alpha pointer is valid even for empty groups. 
+ Ms[i] = n; + Ns[i] = m; + Ks[i] = k; + alpha_ptrs[i] = alpha_buf + i; + + // Empty expert: M==0 -> CUTLASS schedules 0 tiles for this group, so its + // operand / SF / bias / alpha pointers are never dereferenced. Skip + // operand canonicalization (which asserts on operand layout/usage and may + // see no data for a 0-token tensor) and the per-tensor amax (a reduction + // over 0 rows). Leave the operand pointers null; they are not read. + if (n == 0) continue; + ++num_nonempty; + + const GemmParam param = CanonicalizeGemmInput(*iA, transaOp, *iB, transbOp, m, n, k); + + // cuBLAS (col-major) -> CUTLASS (row-major) operand swap: CUTLASS A := TE + // B, CUTLASS B := TE A; M=n, N=m, K=k. Matches the Hopper cutlass path, + // and our kernel's fixed RowMajor-A / ColMajor-B layouts == the NVFP4 TN + // (transa=T, transb=N) case. + a_data[i] = param.B; + a_sf[i] = param.B_scale_inv; + b_data[i] = param.A; + b_sf[i] = param.A_scale_inv; + d_ptrs[i] = oD->data.dptr; + if (bias_present) { + // Bias is per cuBLAS output row (length m) == CUTLASS N == per-col bias. + bias_data[i] = convertNVTETensorCheck(bias[i])->data.dptr; + } + + TensorWrapper alpha_g(alpha_buf + i, std::vector{1}, DType::kFloat32); + nvte_nvfp4_compute_per_tensor_scale(A[i], a_rowwise_amax, B[i], b_rowwise_amax, + /*alpha_in=*/1.0f, alpha_g.data(), stream); + } + + // All experts empty -> nothing to compute (every output is 0-row). Avoid a + // zero-group CUTLASS launch. 
+ if (num_nonempty == 0) { + NVTE_CHECK_CUDA(cudaFreeAsync(alpha_buf, stream)); + return; + } + + const bool fp32_output = + convertNVTETensorCheck(D[0])->data.dtype == DType::kFloat32; + nvfp4_cutlass::run_grouped_per_tensor_gemm(a_data, b_data, a_sf, b_sf, alpha_ptrs, d_ptrs, + bias_data, Ms, Ns, Ks, fp32_output, accumulate, + stream); + NVTE_CHECK_CUDA(cudaFreeAsync(alpha_buf, stream)); + return; + } + } + // Currently only support cutlass group gemm on Hopper Arch if (!(is_hopper && use_cutlass)) { cublas_path(); @@ -1134,3 +1284,85 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor cublas_path(); } } + +// ---- Bench-only direct entry points for the per-tensor NVFP4 grouped GEMM ---- +// These split the work of the NVTE_NVFP4_CUTLASS_GROUPED_GEMM branch above so a +// benchmark can precompute the per-group alpha OUTSIDE a timed region and then +// time ONLY the single CUTLASS grouped launch (the same pure-GEMM methodology +// used for the per-token kernel). They deliberately do NOT recompute alpha or +// allocate temporaries on the hot path. The production dispatch keeps using the +// fused branch above (recomputes alpha every call for bit-exact cuBLASLt parity). + +void nvte_nvfp4_grouped_per_tensor_compute_alpha(const NVTETensor *A, const NVTETensor *B, + const int num_gemms, bool transa, bool transb, + float *alpha, cudaStream_t stream) { + NVTE_API_CALL(nvte_nvfp4_grouped_per_tensor_compute_alpha); + using namespace transformer_engine; + const bool a_rowwise_amax = transa; // transa == T + const bool b_rowwise_amax = !transb; // transb != T + for (int i = 0; i < num_gemms; ++i) { + const auto *iB = convertNVTETensorCheck(B[i]); + const auto [B0, B1] = iB->flat_2d_dims(); + const int n = transb ? 
B1 : B0; // tokens; 0 == empty expert (alpha unused) + if (n == 0) continue; + TensorWrapper alpha_g(alpha + i, std::vector{1}, DType::kFloat32); + nvte_nvfp4_compute_per_tensor_scale(A[i], a_rowwise_amax, B[i], b_rowwise_amax, + /*alpha_in=*/1.0f, alpha_g.data(), stream); + } +} + +void nvte_nvfp4_grouped_per_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D, + const NVTETensor *bias, const int num_gemms, bool transa, + bool transb, bool accumulate, const float *alpha, + cudaStream_t stream) { + NVTE_API_CALL(nvte_nvfp4_grouped_per_tensor_gemm); + using namespace transformer_engine; + if (num_gemms <= 0) return; + const cublasOperation_t transaOp = transa ? CUBLAS_OP_T : CUBLAS_OP_N; + const cublasOperation_t transbOp = transb ? CUBLAS_OP_T : CUBLAS_OP_N; + + auto has_aux = [&](const NVTETensor *p, int i) -> bool { + return p != nullptr && convertNVTETensor(p[i]) != nullptr && + convertNVTETensor(p[i])->has_data(); + }; + const bool bias_present = has_aux(bias, 0); + + std::vector a_data(num_gemms), b_data(num_gemms), a_sf(num_gemms), b_sf(num_gemms); + std::vector alpha_ptrs(num_gemms); + std::vector d_ptrs(num_gemms); + std::vector Ms(num_gemms), Ns(num_gemms), Ks(num_gemms); + std::vector bias_data(bias_present ? num_gemms : 0); + + int num_nonempty = 0; + for (int i = 0; i < num_gemms; ++i) { + const auto *iA = convertNVTETensorCheck(A[i]); + const auto *iB = convertNVTETensorCheck(B[i]); + auto *oD = convertNVTETensorCheck(D[i]); + const auto [A0, A1] = iA->flat_2d_dims(); + const auto [B0, B1] = iB->flat_2d_dims(); + const int m = transa ? A0 : A1; + const int n = transb ? B1 : B0; + const int k = transa ? 
A1 : A0; + Ms[i] = n; + Ns[i] = m; + Ks[i] = k; + alpha_ptrs[i] = alpha + i; + if (n == 0) continue; // empty expert -> 0 tiles, pointers never dereferenced + ++num_nonempty; + + const GemmParam param = CanonicalizeGemmInput(*iA, transaOp, *iB, transbOp, m, n, k); + a_data[i] = param.B; + a_sf[i] = param.B_scale_inv; + b_data[i] = param.A; + b_sf[i] = param.A_scale_inv; + d_ptrs[i] = oD->data.dptr; + if (bias_present) { + bias_data[i] = convertNVTETensorCheck(bias[i])->data.dptr; + } + } + if (num_nonempty == 0) return; + + const bool fp32_output = convertNVTETensorCheck(D[0])->data.dtype == DType::kFloat32; + nvfp4_cutlass::run_grouped_per_tensor_gemm(a_data, b_data, a_sf, b_sf, alpha_ptrs, d_ptrs, + bias_data, Ms, Ns, Ks, fp32_output, accumulate, stream); +} diff --git a/transformer_engine/common/gemm/nvfp4_cutlass_grouped_gemm.cu b/transformer_engine/common/gemm/nvfp4_cutlass_grouped_gemm.cu new file mode 100644 index 0000000000..2b2f1e888a --- /dev/null +++ b/transformer_engine/common/gemm/nvfp4_cutlass_grouped_gemm.cu @@ -0,0 +1,445 @@ +/*************************************************************************************************** + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + **************************************************************************************************/ + +// Grouped (MoE) per-tensor NVFP4xNVFP4 -> BF16 GEMM. A single CUTLASS ptr-array +// grouped launch replaces the per-expert multi-stream cuBLASLt loop used by the +// production NVFP4 grouped path (multi_stream_cublas_gemm). +// +// Design notes: +// * Mainloop / block-scaled / ptr-array config is identical to the per-token +// grouped kernel (nvfp4_cutlass_grouped_gemm on the nvFP4 per-token recipe): +// A row-major, B col-major, D = A @ B^T row-major; NVFP4 = e2m1 data + +// ue4m3 1x16 block scale-factors. 
The caller (dispatcher) realizes TE's +// TN direction and the cuBLAS->CUTLASS A/B swap before calling in. +// * The main structural difference vs. per-token is the epilogue: per-tensor +// scaling collapses the two per-row/col vector broadcasts into one fp32 +// scalar per group, so the no-bias case uses the default LinearCombination +// fusion with the per-group alpha_ptr_array (D = alpha[g] * acc). +// * Optional fused per-group bias (fprop) reuses the per-token grouped +// kernel's array-of-pointers EVT pattern: a hand-built Sm90 EVT computing +// D = ElementOut(alpha[g]*acc + bias[g]) with a ptr-array Sm90RowBroadcast +// bias leaf (ElementBias* -> per-group ptr_row[g]). bias[g] is a length-N +// vector broadcast along M (== cuBLAS per-row bias of length m after the +// A/B swap). The EVT is passed straight to the CollectiveBuilder, so no +// custom FusionCallbacks specialization is required. + +#include "nvfp4_cutlass_grouped_gemm.cuh" + +#include + +#include +#include +#include +#include + +#include "../common.h" +#include "../util/logging.h" +#include "common/util/system.h" +#include "cute/tensor.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/fusion/operations.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/functional.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include 
"cutlass/numeric_types.h" +#include "cutlass/util/packed_stride.hpp" + +namespace transformer_engine { +namespace nvfp4_cutlass { + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +namespace cute_ = cute; +namespace fusion = cutlass::epilogue::fusion; + +// ---- Type config (mirrors the per-token grouped kernel) ------------------- +// +// Templated on the output element so we can instantiate a BF16-output kernel +// (fprop / dgrad, overwrite) and an FP32-output kernel (wgrad, optionally +// accumulating into main_grad). A second flag selects the epilogue fusion: +// stock LinearCombination (no bias) or a hand-built per-group bias EVT (fprop +// only, see below). Everything else is identical across instantiations. + +template +struct PerTensorCfg { + static constexpr bool kHasBias = kHasBias_; + using ElementA = cutlass::nv_float4_t; + using LayoutATag = cutlass::layout::RowMajor; + static constexpr int AlignmentA = 32; + + using ElementB = cutlass::nv_float4_t; + using LayoutBTag = cutlass::layout::ColumnMajor; + static constexpr int AlignmentB = 32; + + // C (accumulate source) and D (output) share the element type. For accumulate + // (wgrad), C == D == the fp32 main_grad buffer; for overwrite, C is unused. + using ElementC = ElementOutT; + using ElementD = ElementOutT; + using LayoutCTag = cutlass::layout::RowMajor; + using LayoutDTag = cutlass::layout::RowMajor; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + + using ElementAccumulator = float; + using ElementCompute = float; + using ElementScale = float; + using ArchTag = cutlass::arch::Sm100; + using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; + + using MmaTileShape = cute_::Shape; + using ClusterShape = cute_::Shape; + + // Ptr-array (grouped) schedules. NVFP4 = e2m1 data + ue4m3 SF, 1x16 vec. 
+ using MainloopSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; + + // Per-group problem shape . + using ProblemShape = cutlass::gemm::GroupProblemShape>; + + static constexpr cutlass::FloatRoundStyle kRoundStyle = + cutlass::FloatRoundStyle::round_to_nearest; + + // Bias element / alignment (per-col bias == one vector of length N per + // group). Only used when kHasBias; matches the output element type. + using ElementBias = ElementOutT; + static constexpr int AlignmentBias = 128 / cutlass::sizeof_bits::value; + + // Per-tensor epilogue: + // * no bias -> D = alpha[g] * acc + beta * C (stock LinearCombination, + // exposes per-group alpha_ptr_array + scalar beta). + // * bias -> D = ElementOut(alpha[g] * acc + bias[g]) (fprop, overwrite). + // The bias variant is a hand-built Sm90 EVT, mirroring the per-token grouped + // kernel's array-of-pointers broadcast pattern: alpha[g] is a per-group + // scalar (Sm90ScalarBroadcastPtrArray, indexed scalar_ptr_array[g]) and + // bias[g] is a per-group length-N vector (Sm90RowBroadcast with ElementBias* + // -> IsArrayOfPointers, indexed ptr_row[g]; broadcast along M / indexed by N + // == cuBLAS per-row bias of length m after the A/B swap). The whole EVT is + // passed straight to the CollectiveBuilder, so no custom FusionCallbacks + // specialization is needed. 
+ using BiasAlphaNode = + fusion::Sm90ScalarBroadcastPtrArray >; + using BiasNode = + fusion::Sm90RowBroadcast<0, MmaTileShape, ElementBias *, ElementCompute, + cute_::Stride, AlignmentBias>; + using BiasEVT = fusion::Sm90EVT< + fusion::Sm90Compute, + BiasAlphaNode, fusion::Sm90AccFetch, BiasNode>; // alpha[g] * acc + bias[g] + + using FusionOp = std::conditional_t< + kHasBias, BiasEVT, + cutlass::epilogue::fusion::LinearCombination >; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementCompute, ElementC, + LayoutCTag *, AlignmentC, ElementD, LayoutDTag *, AlignmentD, EpilogueSchedule, + FusionOp>::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementA, LayoutATag *, AlignmentA, ElementB, LayoutBTag *, + AlignmentB, ElementAccumulator, MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule>::CollectiveOp; + + using GemmKernel = + cutlass::gemm::kernel::GemmUniversal; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB = typename Gemm::GemmKernel::InternalStrideB; + using StrideC = typename Gemm::GemmKernel::InternalStrideC; + using StrideD = typename Gemm::GemmKernel::InternalStrideD; + + using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA; + using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB; + using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + + using ElementADataT = typename ElementA::DataType; + using ElementBDataT = typename ElementB::DataType; + using ElementSFT = typename ElementA::ScaleFactorType; +}; + +// Query the 
SM count exactly once (cudaGetDeviceProperties is very slow). +static int cached_sm_count() { + static int sm = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0); + return sm; +} + +static inline size_t align256(size_t b) { return (b + 255) / 256 * 256; } + +// Process-persistent device buffers reused across launches, to avoid per-call +// cudaMalloc/cudaFree churn. which=0 -> metadata scratch, which=1 -> CUTLASS +// workspace. Assumes grouped GEMMs are issued serially on one stream (the TE +// norm); the stream-ordered free on regrow keeps it safe under that assumption. +static void *persistent_buffer(size_t bytes, cudaStream_t stream, int which) { + static void *bufs[2] = {nullptr, nullptr}; + static size_t caps[2] = {0, 0}; + if (bytes > caps[which]) { + if (bufs[which] != nullptr) { + NVTE_CHECK_CUDA(cudaFreeAsync(bufs[which], stream)); + } + const size_t newcap = bytes + bytes / 2; // slack to avoid frequent regrows + NVTE_CHECK_CUDA(cudaMallocAsync(&bufs[which], newcap, stream)); + caps[which] = newcap; + } + return bufs[which]; +} + +// Reusable pageable host staging buffer for the single batched H2D copy of all +// per-group metadata. Pageable (not pinned) is intentional: cudaMemcpyAsync +// stages pageable source into a driver buffer before returning, so the host +// buffer can be safely overwritten by the next launch without extra sync. 
+static void *persistent_host_buffer(size_t bytes) { + static std::vector buf; + if (buf.size() < bytes) { + buf.resize(bytes + bytes / 2); + } + return buf.data(); +} + +template +static void run_impl(const std::vector &a_data, + const std::vector &b_data, + const std::vector &a_sf, const std::vector &b_sf, + const std::vector &alpha_ptrs, + const std::vector &d_ptrs, + const std::vector &bias_ptrs, const std::vector &Ms, + const std::vector &Ns, const std::vector &Ks, bool accumulate, + cudaStream_t stream) { + using Cfg = PerTensorCfg; + using Gemm = typename Cfg::Gemm; + using StrideA = typename Cfg::StrideA; + using StrideB = typename Cfg::StrideB; + using StrideC = typename Cfg::StrideC; + using StrideD = typename Cfg::StrideD; + using LayoutSFA = typename Cfg::LayoutSFA; + using LayoutSFB = typename Cfg::LayoutSFB; + using Sm1xxBlkScaledConfig = typename Cfg::Sm1xxBlkScaledConfig; + using ElementADataT = typename Cfg::ElementADataT; + using ElementBDataT = typename Cfg::ElementBDataT; + using ElementSFT = typename Cfg::ElementSFT; + using ElementC = typename Cfg::ElementC; + using ElementD = typename Cfg::ElementD; + using ElementBias = typename Cfg::ElementBias; + using ProblemShape = typename Cfg::ProblemShape; + + const int G = static_cast(Ms.size()); + + // Host-side per-group metadata. + std::vector problems(G); + std::vector stride_A_h(G); + std::vector stride_B_h(G); + std::vector stride_C_h(G); + std::vector stride_D_h(G); + std::vector layout_SFA_h(G); + std::vector layout_SFB_h(G); + + std::vector a_ptr_h(G); + std::vector b_ptr_h(G); + std::vector sfa_ptr_h(G); + std::vector sfb_ptr_h(G); + std::vector d_ptr_h(G); + // C source pointers. For accumulate, C == D (read-modify-write main_grad). + std::vector c_ptr_h(G); + // Per-group bias pointers (length N each). Only populated when kHasBias. + std::vector bias_ptr_h(kHasBias ? 
G : 0); + + for (int g = 0; g < G; ++g) { + const int M = Ms[g], N = Ns[g], K = Ks[g]; + problems[g] = {M, N, K}; + stride_A_h[g] = cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}); + stride_B_h[g] = cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}); + stride_C_h[g] = cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1}); + stride_D_h[g] = cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1}); + layout_SFA_h[g] = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute_::make_shape(M, N, K, 1)); + layout_SFB_h[g] = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute_::make_shape(M, N, K, 1)); + + a_ptr_h[g] = reinterpret_cast(a_data[g]); + b_ptr_h[g] = reinterpret_cast(b_data[g]); + sfa_ptr_h[g] = reinterpret_cast(a_sf[g]); + sfb_ptr_h[g] = reinterpret_cast(b_sf[g]); + d_ptr_h[g] = reinterpret_cast(d_ptrs[g]); + c_ptr_h[g] = reinterpret_cast(d_ptrs[g]); + if constexpr (kHasBias) { + bias_ptr_h[g] = reinterpret_cast(bias_ptrs[g]); + } + } + + // Mirror all per-group metadata to device through ONE persistent scratch + // buffer with a single batched H2D copy. 
+ const size_t need = align256(problems.size() * sizeof(problems[0])) + + align256(stride_A_h.size() * sizeof(StrideA)) + + align256(stride_B_h.size() * sizeof(StrideB)) + + align256(stride_C_h.size() * sizeof(StrideC)) + + align256(stride_D_h.size() * sizeof(StrideD)) + + align256(layout_SFA_h.size() * sizeof(LayoutSFA)) + + align256(layout_SFB_h.size() * sizeof(LayoutSFB)) + + align256(a_ptr_h.size() * sizeof(a_ptr_h[0])) + + align256(b_ptr_h.size() * sizeof(b_ptr_h[0])) + + align256(sfa_ptr_h.size() * sizeof(sfa_ptr_h[0])) + + align256(sfb_ptr_h.size() * sizeof(sfb_ptr_h[0])) + + align256(d_ptr_h.size() * sizeof(d_ptr_h[0])) + + align256(c_ptr_h.size() * sizeof(c_ptr_h[0])) + + align256(alpha_ptrs.size() * sizeof(alpha_ptrs[0])) + + align256(bias_ptr_h.size() * sizeof(const ElementBias *)); + + uint8_t *scr = static_cast(persistent_buffer(need, stream, /*which=*/0)); + uint8_t *hscr = static_cast(persistent_host_buffer(need)); + size_t off = 0; + auto put = [&](const auto &vec) { + using T = typename std::decay_t::value_type; + const size_t bytes = vec.size() * sizeof(T); + T *p = reinterpret_cast(scr + off); + std::memcpy(hscr + off, vec.data(), bytes); + off += align256(bytes); + return p; + }; + auto *problems_d = put(problems); + auto *stride_A_d = put(stride_A_h); + auto *stride_B_d = put(stride_B_h); + auto *stride_C_d = put(stride_C_h); + auto *stride_D_d = put(stride_D_h); + auto *layout_SFA_d = put(layout_SFA_h); + auto *layout_SFB_d = put(layout_SFB_h); + auto *a_ptr_d = put(a_ptr_h); + auto *b_ptr_d = put(b_ptr_h); + auto *sfa_ptr_d = put(sfa_ptr_h); + auto *sfb_ptr_d = put(sfb_ptr_h); + auto *d_ptr_d = put(d_ptr_h); + auto *c_ptr_d = put(c_ptr_h); + // Per-group second-level scale pointer array (consumed by alpha_ptr_array). + auto *alpha_ptr_array_d = put(alpha_ptrs); + // Per-group bias pointer array (only staged when kHasBias). 
+ const ElementBias **bias_ptr_array_d = nullptr; + if constexpr (kHasBias) { + bias_ptr_array_d = put(bias_ptr_h); + } + NVTE_CHECK_CUDA(cudaMemcpyAsync(scr, hscr, off, cudaMemcpyHostToDevice, stream)); + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = cached_sm_count(); + + typename Gemm::Arguments arguments; + decltype(arguments.epilogue.thread) fusion_args; + if constexpr (kHasBias) { + // Hand-built EVT args (nested), matching BiasEVT's child order: + // D = ElementOut( homogeneous_multiply_add(alpha[g], acc, bias[g]) ). + // alpha[g] via scalar_ptr_array; bias[g] via the ptr-array RowBroadcast + // (ptr_row, null_default, dRow). dRow == {} -> L-stride 0 since each + // bias_ptr_array[g] already points at group g's length-N bias base. + fusion_args = { + {/*scalars=*/{}, /*scalar_ptrs=*/{}, /*scalar_ptr_arrays=*/{alpha_ptr_array_d}, + /*dScalar=*/{}}, // alpha[g] + {}, // acc + {bias_ptr_array_d, ElementBias(0), {}}, // bias[g] + {} // homogeneous_multiply_add + }; + } else { + fusion_args.alpha = 1.0f; // overridden per-group by alpha_ptr_array + // beta == 1 -> D = alpha[g]*acc + C (accumulate into main_grad); 0 -> overwrite. + fusion_args.beta = accumulate ? 1.0f : 0.0f; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = alpha_ptr_array_d; + fusion_args.beta_ptr_array = nullptr; + fusion_args.dAlpha = {cute_::_0{}, cute_::_0{}, 0}; // one scalar per group + fusion_args.dBeta = {cute_::_0{}, cute_::_0{}, 0}; + } + + // ptr_C is only read when beta != 0 (no-bias accumulate); pass D's buffers so + // accumulate is in-place. The bias path is overwrite-only (no C source). + const ElementC **ptr_C = accumulate ? 
c_ptr_d : nullptr; + + arguments = typename Gemm::Arguments{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {G, problems_d, /*host_problem_shapes=*/nullptr}, + {a_ptr_d, stride_A_d, b_ptr_d, stride_B_d, sfa_ptr_d, layout_SFA_d, sfb_ptr_d, layout_SFB_d}, + {fusion_args, ptr_C, stride_C_d, d_ptr_d, stride_D_d}, + hw_info}; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + void *workspace = nullptr; + if (workspace_size > 0) { + workspace = persistent_buffer(workspace_size, stream, /*which=*/1); + } + + Gemm gemm; + cutlass::Status status = gemm.can_implement(arguments); + NVTE_CHECK(status == cutlass::Status::kSuccess, + "CUTLASS NVFP4 grouped per-tensor GEMM cannot implement: ", + cutlassGetStatusString(status), " (num_groups=", G, ")"); + + status = gemm.initialize(arguments, workspace, stream); + NVTE_CHECK(status == cutlass::Status::kSuccess, + "CUTLASS NVFP4 grouped per-tensor GEMM initialize failed: ", + cutlassGetStatusString(status)); + + status = gemm.run(stream); + NVTE_CHECK(status == cutlass::Status::kSuccess, + "CUTLASS NVFP4 grouped per-tensor GEMM run failed: ", cutlassGetStatusString(status)); +} + +void run_grouped_per_tensor_gemm(const std::vector &a_data, + const std::vector &b_data, + const std::vector &a_sf, + const std::vector &b_sf, + const std::vector &alpha_ptrs, + const std::vector &d_ptrs, + const std::vector &bias_ptrs, + const std::vector &Ms, const std::vector &Ns, + const std::vector &Ks, bool fp32_output, bool accumulate, + cudaStream_t stream) { + static const std::vector kNoBias; + const bool has_bias = !bias_ptrs.empty(); + if (has_bias) { + // Fused per-group bias is fprop-only: BF16 output, overwrite (no accumulate). 
+ NVTE_CHECK(!fp32_output && !accumulate, + "CUTLASS NVFP4 grouped per-tensor GEMM: fused bias requires BF16 output and " + "overwrite (no accumulate)."); + run_impl(a_data, b_data, a_sf, b_sf, alpha_ptrs, d_ptrs, + bias_ptrs, Ms, Ns, Ks, accumulate, stream); + } else if (fp32_output) { + run_impl(a_data, b_data, a_sf, b_sf, alpha_ptrs, d_ptrs, kNoBias, Ms, + Ns, Ks, accumulate, stream); + } else { + NVTE_CHECK(!accumulate, + "CUTLASS NVFP4 grouped per-tensor GEMM: accumulate requires FP32 output."); + run_impl(a_data, b_data, a_sf, b_sf, alpha_ptrs, + d_ptrs, kNoBias, Ms, Ns, Ks, accumulate, + stream); + } +} + +#else // !CUTLASS_ARCH_MMA_SM100_SUPPORTED + +void run_grouped_per_tensor_gemm(const std::vector &, const std::vector &, + const std::vector &, const std::vector &, + const std::vector &, const std::vector &, + const std::vector &, const std::vector &, + const std::vector &, const std::vector &, bool, bool, + cudaStream_t) { + NVTE_ERROR( + "CUTLASS NVFP4 grouped per-tensor GEMM requires SM100 (Blackwell). Build with " + "sm_100a/sm_100f."); +} +#endif // CUTLASS_ARCH_MMA_SM100_SUPPORTED + +} // namespace nvfp4_cutlass +} // namespace transformer_engine diff --git a/transformer_engine/common/gemm/nvfp4_cutlass_grouped_gemm.cuh b/transformer_engine/common/gemm/nvfp4_cutlass_grouped_gemm.cuh new file mode 100644 index 0000000000..3a1af9df7d --- /dev/null +++ b/transformer_engine/common/gemm/nvfp4_cutlass_grouped_gemm.cuh @@ -0,0 +1,61 @@ +/*************************************************************************************************** + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + **************************************************************************************************/ + +/** + * @file nvfp4_cutlass_grouped_gemm.cuh + * @brief Single-launch CUTLASS grouped (MoE) GEMM for *per-tensor* NVFP4 on + * SM100 (Blackwell). 
Replaces the multi-stream cuBLASLt loop in the + * production NVFP4 grouped path with one CUTLASS ptr-array launch. + * + * Compared to the per-token grouped kernel, the per-tensor second-level scale + * collapses to a single fp32 scalar per group + * alpha[g] = amax_A[g] * amax_B[g] / (fp4_max^2 * fp8_max^2), + * applied through the epilogue's per-group alpha_ptr_array (default + * LinearCombination), so no vector row/col broadcast EVT is needed. + * + * The launcher takes raw device-pointer vectors so the TE-tensor / scale / + * layout extraction (and the cuBLAS->CUTLASS A/B swap) lives in the dispatcher + * in cublaslt_gemm.cu, reusing CanonicalizeGemmInput for parity with cuBLASLt. + */ + +#pragma once + +#include + +#include + +namespace transformer_engine { +namespace nvfp4_cutlass { + +// Single-launch per-tensor NVFP4 grouped GEMM. CUTLASS computes, per group, +// D[g] = out_dtype(alpha[g] * (A[g] @ B[g]^T) + beta * C[g]) +// with A[g] row-major (M,K) FP4, B[g] col-major (N,K) FP4, D[g] row-major +// (M,N). a_sf/b_sf are the swizzled e4m3 1x16 block-scale buffers for A/B. +// alpha_ptrs[g] points to a single device fp32 holding the per-group global +// second-level scale. M, N, K must be multiples of 128. +// +// Output / accumulate / bias modes: +// * fp32_output == false -> BF16 output, overwrite (fprop / dgrad). +// * fp32_output == true -> FP32 output (wgrad). With accumulate == true, +// beta = 1 and C == D, i.e. D (== main_grad) is read-modify-written +// in-place (Megatron wgrad fusion). accumulate requires fp32_output. +// * bias_ptrs non-empty -> fused per-group bias (fprop): each bias_ptrs[g] +// points to group g's length-N bias vector (D-element dtype), added in the +// epilogue (D[g] = alpha[g]*acc + bias[g]). Requires BF16 output and +// overwrite (no accumulate). Pass an empty vector when there is no bias. 
+void run_grouped_per_tensor_gemm(const std::vector &a_data, + const std::vector &b_data, + const std::vector &a_sf, + const std::vector &b_sf, + const std::vector &alpha_ptrs, + const std::vector &d_ptrs, + const std::vector &bias_ptrs, + const std::vector &Ms, const std::vector &Ns, + const std::vector &Ks, bool fp32_output, bool accumulate, + cudaStream_t stream); + +} // namespace nvfp4_cutlass +} // namespace transformer_engine diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index a99e0946ef..bffd61acd9 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -298,6 +298,54 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor bool accumulate, bool use_split_accumulator, int math_sm_count, cudaStream_t stream); +/*! \brief Compute the per-group second-level scale (alpha) for a per-tensor NVFP4 grouped GEMM. + * + * Benchmark-only helper. Mirrors the per-expert alpha computation that the + * production dispatch (nvte_multi_tensor_gemm, NVTE_NVFP4_CUTLASS_GROUPED_GEMM + * path) performs inline, but lets a caller run it ONCE outside a timed region so + * that nvte_nvfp4_grouped_per_tensor_gemm can be timed in isolation (matching the + * pure-GEMM timing methodology used for the per-token kernel). NOT used by the + * production path, which always recomputes alpha to stay bit-for-bit with cuBLASLt. + * + * \param[in] A List of A matrices (per-tensor NVFP4). + * \param[in] B List of B matrices (per-tensor NVFP4). + * \param[in] num_gemms Number of GEMMs in the group. + * \param[in] transa Whether A matrices are transposed (cuBLAS convention). + * \param[in] transb Whether B matrices are transposed (cuBLAS convention). + * \param[out] alpha Device buffer of num_gemms float32; alpha[g] is written for non-empty + * groups. 
Slots for empty (0-token) groups are left untouched. + * \param[in] stream CUDA stream. + */ +void nvte_nvfp4_grouped_per_tensor_compute_alpha(const NVTETensor *A, const NVTETensor *B, + const int num_gemms, bool transa, bool transb, + float *alpha, cudaStream_t stream); + +/*! \brief Single-launch per-tensor NVFP4 grouped GEMM with a precomputed alpha. + * + * Benchmark-only helper. Equivalent to the NVTE_NVFP4_CUTLASS_GROUPED_GEMM branch + * of nvte_multi_tensor_gemm, but takes a caller-provided per-group alpha buffer + * (no per-expert recompute, no temporary allocation) so that ONLY the CUTLASS + * grouped-GEMM launch is timed. The scaling factors of A/B must already be in the + * GEMM-swizzled layout (e.g. via the pytorch multi_tensor_swizzle_scales_for_gemm_ + * helper), exactly as the production grouped-GEMM wrapper arranges before calling + * nvte_multi_tensor_gemm. Empty (0-token) groups schedule 0 tiles and are skipped. + * + * \param[in] A List of A matrices (per-tensor NVFP4, swizzled scales). + * \param[in] B List of B matrices (per-tensor NVFP4, swizzled scales). + * \param[in,out] D List of output matrices (BF16 or FP32). + * \param[in] bias List of per-group bias tensors, or NULL for no bias. + * \param[in] num_gemms Number of GEMMs in the group. + * \param[in] transa Whether A matrices are transposed (cuBLAS convention). + * \param[in] transb Whether B matrices are transposed (cuBLAS convention). + * \param[in] accumulate Whether to accumulate into D in place (requires FP32 output). + * \param[in] alpha Device buffer of num_gemms float32 second-level scales. + * \param[in] stream CUDA stream. + */ +void nvte_nvfp4_grouped_per_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D, + const NVTETensor *bias, const int num_gemms, bool transa, + bool transb, bool accumulate, const float *alpha, + cudaStream_t stream); + /*! \brief Return the required size in bytes for the setup workspace of grouped GEMM. 
* * The setup workspace stores pointer arrays and per-matrix dimension arrays used diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 2b4f899e1d..bd160c94d3 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -163,6 +163,17 @@ std::optional> te_general_grouped_gemm( std::vector pre_gelu_out, bool grad, std::vector workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count); +// Bench-only: precompute per-group alpha for a per-tensor NVFP4 grouped GEMM. +at::Tensor nvfp4_grouped_per_tensor_compute_alpha(std::vector A, bool transa, + std::vector B, bool transb); + +// Bench-only: single-launch per-tensor NVFP4 grouped GEMM with precomputed alpha +// Times only the GEMM. +void nvfp4_grouped_per_tensor_gemm(std::vector A, bool transa, + std::vector B, bool transb, + std::vector D, std::vector bias, + at::Tensor alpha, bool accumulate); + py::object te_general_grouped_gemm_for_grouped_tensor( py::handle A, bool transa, py::handle B, bool transb, py::handle D, py::object bias, std::optional bias_scale, at::Tensor alpha, at::Tensor beta, diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index b1e552ec8b..8a52dbff0f 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -636,6 +636,80 @@ std::optional> te_general_grouped_gemm( return bias; } +at::Tensor nvfp4_grouped_per_tensor_compute_alpha(std::vector A, bool transa, + std::vector B, bool transb) { + // Bench-only: precompute the per-group second-level scale (alpha) for a + // per-tensor NVFP4 grouped GEMM so the GEMM launch can be timed in isolation. 
+ const size_t num_gemms = A.size(); + NVTE_CHECK(B.size() == num_gemms, "A and B must have matching lengths (", num_gemms, " vs ", + B.size(), ")."); + const auto none = py::none(); + std::vector te_A_wrappers, te_B_wrappers; + for (size_t i = 0; i < num_gemms; i++) { + te_A_wrappers.emplace_back(makeTransformerEngineTensor(A[i], none)); + te_B_wrappers.emplace_back(makeTransformerEngineTensor(B[i], none)); + } + std::vector te_A_vector, te_B_vector; + for (size_t i = 0; i < num_gemms; i++) { + te_A_vector.emplace_back(te_A_wrappers[i].data()); + te_B_vector.emplace_back(te_B_wrappers[i].data()); + } + + auto opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + at::Tensor alpha = at::zeros({static_cast(num_gemms)}, opts); + at::cuda::CUDAGuard device_guard(alpha.device()); + NVTE_SCOPED_GIL_RELEASE({ + nvte_nvfp4_grouped_per_tensor_compute_alpha( + te_A_vector.data(), te_B_vector.data(), static_cast(num_gemms), transa, transb, + reinterpret_cast(alpha.data_ptr()), at::cuda::getCurrentCUDAStream()); + }); + return alpha; +} + +void nvfp4_grouped_per_tensor_gemm(std::vector A, bool transa, + std::vector B, bool transb, + std::vector D, std::vector bias, + at::Tensor alpha, bool accumulate) { + // Bench-only: single-launch per-tensor NVFP4 grouped GEMM with a precomputed + // alpha. Scales of A/B must already be GEMM-swizzled by the caller (e.g. via + // multi_tensor_swizzle_scales_for_gemm_). Only the GEMM launch happens here. 
+ const size_t num_gemms = A.size(); + NVTE_CHECK(B.size() == num_gemms && D.size() == num_gemms, + "A, B, D must have matching lengths."); + const bool have_bias = !bias.empty(); + NVTE_CHECK(!have_bias || bias.size() == num_gemms, + "bias must be empty or have length num_gemms."); + at::cuda::CUDAGuard device_guard(D[0].device()); + + const auto none = py::none(); + std::vector te_A_wrappers, te_B_wrappers, te_D_wrappers, te_bias_wrappers; + for (size_t i = 0; i < num_gemms; i++) { + te_A_wrappers.emplace_back(makeTransformerEngineTensor(A[i], none)); + te_B_wrappers.emplace_back(makeTransformerEngineTensor(B[i], none)); + te_D_wrappers.emplace_back(makeTransformerEngineTensor(D[i])); + if (have_bias) { + te_bias_wrappers.emplace_back(makeTransformerEngineTensor(bias[i])); + } + } + std::vector te_A_vector, te_B_vector, te_D_vector, te_bias_vector; + for (size_t i = 0; i < num_gemms; i++) { + te_A_vector.emplace_back(te_A_wrappers[i].data()); + te_B_vector.emplace_back(te_B_wrappers[i].data()); + te_D_vector.emplace_back(te_D_wrappers[i].data()); + if (have_bias) { + te_bias_vector.emplace_back(te_bias_wrappers[i].data()); + } + } + + NVTE_SCOPED_GIL_RELEASE({ + nvte_nvfp4_grouped_per_tensor_gemm( + te_A_vector.data(), te_B_vector.data(), te_D_vector.data(), + have_bias ? 
te_bias_vector.data() : nullptr, static_cast(num_gemms), transa, transb, + accumulate, reinterpret_cast(alpha.data_ptr()), + at::cuda::getCurrentCUDAStream()); + }); +} + py::object te_general_grouped_gemm_for_grouped_tensor( py::handle A, bool transa, py::handle B, bool transb, py::handle D, py::object bias, std::optional bias_scale, at::Tensor alpha, at::Tensor beta, diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index d6089b1e01..e8249e5e13 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -348,6 +348,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Required workspace size for grouped GEMM setup"); m.def("te_general_grouped_gemm", &transformer_engine::pytorch::te_general_grouped_gemm, "Grouped GEMM"); + m.def("nvfp4_grouped_per_tensor_compute_alpha", + &transformer_engine::pytorch::nvfp4_grouped_per_tensor_compute_alpha, + "Bench-only: precompute per-group alpha for per-tensor NVFP4 grouped GEMM", + py::arg("A"), py::arg("transa"), py::arg("B"), py::arg("transb")); + m.def("nvfp4_grouped_per_tensor_gemm", + &transformer_engine::pytorch::nvfp4_grouped_per_tensor_gemm, + "Bench-only: single-launch per-tensor NVFP4 grouped GEMM with precomputed alpha", + py::arg("A"), py::arg("transa"), py::arg("B"), py::arg("transb"), py::arg("D"), + py::arg("bias"), py::arg("alpha"), py::arg("accumulate")); m.def("te_general_grouped_gemm_for_grouped_tensor", &transformer_engine::pytorch::te_general_grouped_gemm_for_grouped_tensor, "Grouped GEMM for GroupedTensor"); From bebf1b999d84297852742b7f64214371a0379b76 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 17 Jun 2026 08:02:37 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- benchmarks/linear/benchmark_grouped_linear.py | 50 
+++++++++---- tests/pytorch/test_grouped_linear.py | 70 ++++++++++++------- .../common/gemm/cublaslt_gemm.cu | 9 ++- .../common/gemm/nvfp4_cutlass_grouped_gemm.cu | 68 +++++++++--------- .../gemm/nvfp4_cutlass_grouped_gemm.cuh | 17 ++--- .../pytorch/csrc/extensions/gemm.cpp | 13 ++-- .../pytorch/csrc/extensions/pybind.cpp | 4 +- 7 files changed, 132 insertions(+), 99 deletions(-) diff --git a/benchmarks/linear/benchmark_grouped_linear.py b/benchmarks/linear/benchmark_grouped_linear.py index 8c02173eb4..88228b8412 100644 --- a/benchmarks/linear/benchmark_grouped_linear.py +++ b/benchmarks/linear/benchmark_grouped_linear.py @@ -406,10 +406,18 @@ def ms_pure() -> Optional[float]: with _nvfp4_gg_backend(False): # multi-stream, on pre-swizzled operands return _nvfp4_gg_time_us( lambda: general_grouped_gemm( - A, B, out, [None] * len(Ms), out[0].dtype, - layout=layout, m_splits=m_splits, single_output=single_output, grad=grad, + A, + B, + out, + [None] * len(Ms), + out[0].dtype, + layout=layout, + m_splits=m_splits, + single_output=single_output, + grad=grad, ), - warmup, iters, + warmup, + iters, ) except Exception: # noqa: BLE001 return None @@ -420,7 +428,8 @@ def cu_pure() -> Optional[float]: lambda: tex.nvfp4_grouped_per_tensor_gemm( A, transa, B, transb, d_groups, [], alpha, False ), - warmup, iters, + warmup, + iters, ) except Exception: # noqa: BLE001 return None @@ -439,8 +448,15 @@ def _nvfp4_gg_bench( def run(): general_grouped_gemm( - A, B, out, [None] * len(Ms), out[0].dtype, - layout=layout, m_splits=m_splits, single_output=single_output, grad=grad, + A, + B, + out, + [None] * len(Ms), + out[0].dtype, + layout=layout, + m_splits=m_splits, + single_output=single_output, + grad=grad, ) def timed(cutlass: bool) -> Optional[float]: @@ -470,15 +486,19 @@ def run_nvfp4_grouped_gemm_comparison(layouts, configs, warmup, iters, want_pure return layout_label = {"TN": "TN fprop", "NN": "NN dgrad", "NT": "NT wgrad"} - pure_hdr = "" if not want_pure else ( - " + PURE 
row (fair kernel-vs-kernel): both pre-swizzled, swizzle excluded from both.\n" + pure_hdr = ( + "" + if not want_pure + else ( + " + PURE row (fair kernel-vs-kernel): both pre-swizzled, swizzle excluded from both.\n" + ) ) print( - f"\nNVFP4 grouped GEMM: CUTLASS vs multi-stream cuBLASLt " + "\nNVFP4 grouped GEMM: CUTLASS vs multi-stream cuBLASLt " f"[warmup={warmup} iters={iters}]\n" - f" DISPATCH row (real prod): multi-stream = env=0 (4-stream cuBLASLt), cutlass = env=1.\n" + " DISPATCH row (real prod): multi-stream = env=0 (4-stream cuBLASLt), cutlass = env=1.\n" f"{pure_hdr}" - f" speedup = multi-stream / cutlass (>1 => cutlass faster).\n" + " speedup = multi-stream / cutlass (>1 => cutlass faster).\n" ) def _ms(us: Optional[float]) -> str: @@ -609,14 +629,14 @@ def _emit(shape, n, k, tok, row, cu_us, ms_us): gg_configs = [(Ms, args.output_dim, args.hidden_dim)] else: gg_configs = [ - (_nvfp4_gg_token_counts(8, 128, False, 0), 2048, 2048), # small (launch-bound) + (_nvfp4_gg_token_counts(8, 128, False, 0), 2048, 2048), # small (launch-bound) (_nvfp4_gg_token_counts(8, 256, False, 0), 2048, 2048), (_nvfp4_gg_token_counts(8, 512, False, 0), 2048, 2048), - (_nvfp4_gg_token_counts(8, 256, True, 1), 2048, 2048), # imbalanced + (_nvfp4_gg_token_counts(8, 256, True, 1), 2048, 2048), # imbalanced (_nvfp4_gg_token_counts(16, 256, False, 0), 2048, 2048), - (_nvfp4_gg_token_counts(16, 256, True, 2), 4096, 2048), # imbalanced, wider N + (_nvfp4_gg_token_counts(16, 256, True, 2), 4096, 2048), # imbalanced, wider N (_nvfp4_gg_token_counts(32, 128, False, 0), 2048, 2048), # many small experts - (_nvfp4_gg_token_counts(32, 256, True, 3), 2048, 2048), # many imbalanced + (_nvfp4_gg_token_counts(32, 256, True, 3), 2048, 2048), # many imbalanced ] run_nvfp4_grouped_gemm_comparison( args.layouts, gg_configs, args.gemm_warmup, args.gemm_iters, not args.no_pure diff --git a/tests/pytorch/test_grouped_linear.py b/tests/pytorch/test_grouped_linear.py index 7f05772a39..7834f08293 
100644 --- a/tests/pytorch/test_grouped_linear.py +++ b/tests/pytorch/test_grouped_linear.py @@ -931,9 +931,7 @@ def test_grouped_gemm_cutlass_empty_groups(layout, monkeypatch): # which is SM100-only and additionally fuses bias (fprop) / accumulate (wgrad). # ============================================================================= _NVFP4_CUTLASS_ENV = "NVTE_NVFP4_CUTLASS_GROUPED_GEMM" -nvfp4_cutlass_grouped_available = ( - nvfp4_available and torch.cuda.get_device_capability()[0] == 10 -) +nvfp4_cutlass_grouped_available = nvfp4_available and torch.cuda.get_device_capability()[0] == 10 def _nvfp4_pertensor_quantize(hp: torch.Tensor): @@ -994,8 +992,9 @@ def _nvfp4_dequant_reference(A, B, *, layout: str, bias=None, init=None): return refs -def _run_nvfp4_grouped(A, B, out, *, layout, grad, accumulate, m_splits, single_output, - bias, cutlass, monkeypatch): +def _run_nvfp4_grouped( + A, B, out, *, layout, grad, accumulate, m_splits, single_output, bias, cutlass, monkeypatch +): monkeypatch.setenv(_NVFP4_CUTLASS_ENV, "1" if cutlass else "0") general_grouped_gemm( A, @@ -1024,16 +1023,15 @@ def _assert_nvfp4_grouped_parity(out_ms, out_cu, hp): assert ms_inf > 1e-6, "reference output is ~0 (operand/quant bug, not a real check)" # Backend consistency: overwrite paths are bit-identical; bias / accumulate # add a ~1 ULP fp32/bf16 rounding diff. Allow either bound. - assert abs_d <= 5e-2 or rel_d <= 2e-2, ( - f"cutlass vs multi-stream diverged: max_abs={abs_d:.4g}, rel={rel_d:.4g}" - ) + assert ( + abs_d <= 5e-2 or rel_d <= 2e-2 + ), f"cutlass vs multi-stream diverged: max_abs={abs_d:.4g}, rel={rel_d:.4g}" # Correctness: cutlass no worse than production vs the neutral reference. 
_, ms_hp, _ = _diff(h, ms) _, cu_hp, _ = _diff(h, cu) - assert cu_hp <= max(0.05, ms_hp * 1.3), ( - f"cutlass less accurate than multi-stream vs dequant ref: " - f"cu={cu_hp:.4g}, ms={ms_hp:.4g}" - ) + assert cu_hp <= max( + 0.05, ms_hp * 1.3 + ), f"cutlass less accurate than multi-stream vs dequant ref: cu={cu_hp:.4g}, ms={ms_hp:.4g}" def _build_nvfp4_grouped_operands(layout, m_splits, k, n, *, accumulate, use_bias, odt, dev): @@ -1083,12 +1081,32 @@ def _run_nvfp4_gemm_case(layout, fp32_out, accumulate, use_bias, m_splits, k, n, A, B, out_ms, out_cu, grad, single_output, bias, init = _build_nvfp4_grouped_operands( layout, m_splits, k, n, accumulate=accumulate, use_bias=use_bias, odt=odt, dev=dev ) - _run_nvfp4_grouped(A, B, out_ms, layout=layout, grad=grad, accumulate=accumulate, - m_splits=m_splits, single_output=single_output, bias=bias, - cutlass=False, monkeypatch=monkeypatch) - _run_nvfp4_grouped(A, B, out_cu, layout=layout, grad=grad, accumulate=accumulate, - m_splits=m_splits, single_output=single_output, bias=bias, - cutlass=True, monkeypatch=monkeypatch) + _run_nvfp4_grouped( + A, + B, + out_ms, + layout=layout, + grad=grad, + accumulate=accumulate, + m_splits=m_splits, + single_output=single_output, + bias=bias, + cutlass=False, + monkeypatch=monkeypatch, + ) + _run_nvfp4_grouped( + A, + B, + out_cu, + layout=layout, + grad=grad, + accumulate=accumulate, + m_splits=m_splits, + single_output=single_output, + bias=bias, + cutlass=True, + monkeypatch=monkeypatch, + ) hp = _nvfp4_dequant_reference(A, B, layout=layout, bias=bias, init=init) _assert_nvfp4_grouped_parity(out_ms, out_cu, hp) @@ -1179,7 +1197,12 @@ def test_nvfp4_cutlass_grouped_linear(groups, bias, fuse_wgrad_accumulation, mon torch.manual_seed(0) model = GroupedLinear( - groups, K, N, bias=bias, params_dtype=torch.bfloat16, device="cuda", + groups, + K, + N, + bias=bias, + params_dtype=torch.bfloat16, + device="cuda", fuse_wgrad_accumulation=fuse_wgrad_accumulation, ).eval() x = 
torch.randn(total_m, K, dtype=torch.bfloat16, device="cuda", requires_grad=True) @@ -1202,8 +1225,7 @@ def run(cutlass: bool): with autocast(enabled=True, recipe=nvfp4_recipe): out = model(x, m_splits) out.backward(dy) - snap = {"out": out.detach().float().clone(), - "dgrad": x.grad.detach().float().clone()} + snap = {"out": out.detach().float().clone(), "dgrad": x.grad.detach().float().clone()} for i in range(groups): w = getattr(model, f"weight{i}") g = w.main_grad if fuse_wgrad_accumulation else w.grad @@ -1218,9 +1240,9 @@ def run(cutlass: bool): test = run(cutlass=True) for key in ref: abs_d, rel_d, _ = _diff(ref[key], test[key]) - assert abs_d <= 5e-2 or rel_d <= 2e-2, ( - f"{key}: cutlass vs multi-stream diverged (max_abs={abs_d:.4g}, rel={rel_d:.4g})" - ) + assert ( + abs_d <= 5e-2 or rel_d <= 2e-2 + ), f"{key}: cutlass vs multi-stream diverged (max_abs={abs_d:.4g}, rel={rel_d:.4g})" def _pack_grouped_tensor(grouped_tensor: GroupedTensor, tensors: List[torch.Tensor]) -> None: diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 715cff008f..9a508c63b9 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1087,8 +1087,7 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor // behavior is unchanged by default. const bool is_blackwell = (transformer_engine::cuda::sm_arch(current_device) >= 100 && transformer_engine::cuda::sm_arch(current_device) < 110); - if (is_blackwell && - transformer_engine::getenv("NVTE_NVFP4_CUTLASS_GROUPED_GEMM", false)) { + if (is_blackwell && transformer_engine::getenv("NVTE_NVFP4_CUTLASS_GROUPED_GEMM", false)) { using namespace transformer_engine; const cublasOperation_t transaOp = transa ? CUBLAS_OP_T : CUBLAS_OP_N; const cublasOperation_t transbOp = transb ? 
CUBLAS_OP_T : CUBLAS_OP_N; @@ -1211,8 +1210,7 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor return; } - const bool fp32_output = - convertNVTETensorCheck(D[0])->data.dtype == DType::kFloat32; + const bool fp32_output = convertNVTETensorCheck(D[0])->data.dtype == DType::kFloat32; nvfp4_cutlass::run_grouped_per_tensor_gemm(a_data, b_data, a_sf, b_sf, alpha_ptrs, d_ptrs, bias_data, Ms, Ns, Ks, fp32_output, accumulate, stream); @@ -1364,5 +1362,6 @@ void nvte_nvfp4_grouped_per_tensor_gemm(const NVTETensor *A, const NVTETensor *B const bool fp32_output = convertNVTETensorCheck(D[0])->data.dtype == DType::kFloat32; nvfp4_cutlass::run_grouped_per_tensor_gemm(a_data, b_data, a_sf, b_sf, alpha_ptrs, d_ptrs, - bias_data, Ms, Ns, Ks, fp32_output, accumulate, stream); + bias_data, Ms, Ns, Ks, fp32_output, accumulate, + stream); } diff --git a/transformer_engine/common/gemm/nvfp4_cutlass_grouped_gemm.cu b/transformer_engine/common/gemm/nvfp4_cutlass_grouped_gemm.cu index 2b2f1e888a..cfe5a43223 100644 --- a/transformer_engine/common/gemm/nvfp4_cutlass_grouped_gemm.cu +++ b/transformer_engine/common/gemm/nvfp4_cutlass_grouped_gemm.cu @@ -26,8 +26,6 @@ // A/B swap). The EVT is passed straight to the CollectiveBuilder, so no // custom FusionCallbacks specialization is required. -#include "nvfp4_cutlass_grouped_gemm.cuh" - #include #include @@ -56,6 +54,7 @@ #include "cutlass/gemm/kernel/gemm_universal.hpp" #include "cutlass/numeric_types.h" #include "cutlass/util/packed_stride.hpp" +#include "nvfp4_cutlass_grouped_gemm.cuh" namespace transformer_engine { namespace nvfp4_cutlass { @@ -130,8 +129,8 @@ struct PerTensorCfg { // passed straight to the CollectiveBuilder, so no custom FusionCallbacks // specialization is needed. 
using BiasAlphaNode = - fusion::Sm90ScalarBroadcastPtrArray >; + fusion::Sm90ScalarBroadcastPtrArray>; using BiasNode = fusion::Sm90RowBroadcast<0, MmaTileShape, ElementBias *, ElementCompute, cute_::Stride, AlignmentBias>; @@ -139,10 +138,10 @@ struct PerTensorCfg { fusion::Sm90Compute, BiasAlphaNode, fusion::Sm90AccFetch, BiasNode>; // alpha[g] * acc + bias[g] - using FusionOp = std::conditional_t< - kHasBias, BiasEVT, - cutlass::epilogue::fusion::LinearCombination >; + using FusionOp = + std::conditional_t>; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, OperatorClass, MmaTileShape, ClusterShape, @@ -215,13 +214,12 @@ static void *persistent_host_buffer(size_t bytes) { template static void run_impl(const std::vector &a_data, - const std::vector &b_data, - const std::vector &a_sf, const std::vector &b_sf, + const std::vector &b_data, const std::vector &a_sf, + const std::vector &b_sf, const std::vector &alpha_ptrs, - const std::vector &d_ptrs, - const std::vector &bias_ptrs, const std::vector &Ms, - const std::vector &Ns, const std::vector &Ks, bool accumulate, - cudaStream_t stream) { + const std::vector &d_ptrs, const std::vector &bias_ptrs, + const std::vector &Ms, const std::vector &Ns, + const std::vector &Ks, bool accumulate, cudaStream_t stream) { using Cfg = PerTensorCfg; using Gemm = typename Cfg::Gemm; using StrideA = typename Cfg::StrideA; @@ -346,10 +344,10 @@ static void run_impl(const std::vector &a_data, // bias_ptr_array[g] already points at group g's length-N bias base. 
fusion_args = { {/*scalars=*/{}, /*scalar_ptrs=*/{}, /*scalar_ptr_arrays=*/{alpha_ptr_array_d}, - /*dScalar=*/{}}, // alpha[g] - {}, // acc - {bias_ptr_array_d, ElementBias(0), {}}, // bias[g] - {} // homogeneous_multiply_add + /*dScalar=*/{}}, // alpha[g] + {}, // acc + {bias_ptr_array_d, ElementBias(0), {}}, // bias[g] + {} // homogeneous_multiply_add }; } else { fusion_args.alpha = 1.0f; // overridden per-group by alpha_ptr_array @@ -387,25 +385,22 @@ static void run_impl(const std::vector &a_data, cutlassGetStatusString(status), " (num_groups=", G, ")"); status = gemm.initialize(arguments, workspace, stream); - NVTE_CHECK(status == cutlass::Status::kSuccess, - "CUTLASS NVFP4 grouped per-tensor GEMM initialize failed: ", - cutlassGetStatusString(status)); + NVTE_CHECK( + status == cutlass::Status::kSuccess, + "CUTLASS NVFP4 grouped per-tensor GEMM initialize failed: ", cutlassGetStatusString(status)); status = gemm.run(stream); NVTE_CHECK(status == cutlass::Status::kSuccess, "CUTLASS NVFP4 grouped per-tensor GEMM run failed: ", cutlassGetStatusString(status)); } -void run_grouped_per_tensor_gemm(const std::vector &a_data, - const std::vector &b_data, - const std::vector &a_sf, - const std::vector &b_sf, - const std::vector &alpha_ptrs, - const std::vector &d_ptrs, - const std::vector &bias_ptrs, - const std::vector &Ms, const std::vector &Ns, - const std::vector &Ks, bool fp32_output, bool accumulate, - cudaStream_t stream) { +void run_grouped_per_tensor_gemm( + const std::vector &a_data, const std::vector &b_data, + const std::vector &a_sf, const std::vector &b_sf, + const std::vector &alpha_ptrs, const std::vector &d_ptrs, + const std::vector &bias_ptrs, const std::vector &Ms, + const std::vector &Ns, const std::vector &Ks, bool fp32_output, bool accumulate, + cudaStream_t stream) { static const std::vector kNoBias; const bool has_bias = !bias_ptrs.empty(); if (has_bias) { @@ -421,16 +416,17 @@ void run_grouped_per_tensor_gemm(const std::vector &a_data, } else 
{ NVTE_CHECK(!accumulate, "CUTLASS NVFP4 grouped per-tensor GEMM: accumulate requires FP32 output."); - run_impl(a_data, b_data, a_sf, b_sf, alpha_ptrs, - d_ptrs, kNoBias, Ms, Ns, Ks, accumulate, - stream); + run_impl( + a_data, b_data, a_sf, b_sf, alpha_ptrs, d_ptrs, kNoBias, Ms, Ns, Ks, accumulate, stream); } } #else // !CUTLASS_ARCH_MMA_SM100_SUPPORTED -void run_grouped_per_tensor_gemm(const std::vector &, const std::vector &, - const std::vector &, const std::vector &, +void run_grouped_per_tensor_gemm(const std::vector &, + const std::vector &, + const std::vector &, + const std::vector &, const std::vector &, const std::vector &, const std::vector &, const std::vector &, const std::vector &, const std::vector &, bool, bool, diff --git a/transformer_engine/common/gemm/nvfp4_cutlass_grouped_gemm.cuh b/transformer_engine/common/gemm/nvfp4_cutlass_grouped_gemm.cuh index 3a1af9df7d..3d832c450a 100644 --- a/transformer_engine/common/gemm/nvfp4_cutlass_grouped_gemm.cuh +++ b/transformer_engine/common/gemm/nvfp4_cutlass_grouped_gemm.cuh @@ -46,16 +46,13 @@ namespace nvfp4_cutlass { // points to group g's length-N bias vector (D-element dtype), added in the // epilogue (D[g] = alpha[g]*acc + bias[g]). Requires BF16 output and // overwrite (no accumulate). Pass an empty vector when there is no bias. 
-void run_grouped_per_tensor_gemm(const std::vector &a_data, - const std::vector &b_data, - const std::vector &a_sf, - const std::vector &b_sf, - const std::vector &alpha_ptrs, - const std::vector &d_ptrs, - const std::vector &bias_ptrs, - const std::vector &Ms, const std::vector &Ns, - const std::vector &Ks, bool fp32_output, bool accumulate, - cudaStream_t stream); +void run_grouped_per_tensor_gemm( + const std::vector &a_data, const std::vector &b_data, + const std::vector &a_sf, const std::vector &b_sf, + const std::vector &alpha_ptrs, const std::vector &d_ptrs, + const std::vector &bias_ptrs, const std::vector &Ms, + const std::vector &Ns, const std::vector &Ks, bool fp32_output, bool accumulate, + cudaStream_t stream); } // namespace nvfp4_cutlass } // namespace transformer_engine diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 8a52dbff0f..894cf7c82e 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -674,8 +674,7 @@ void nvfp4_grouped_per_tensor_gemm(std::vector A, bool transa, // alpha. Scales of A/B must already be GEMM-swizzled by the caller (e.g. via // multi_tensor_swizzle_scales_for_gemm_). Only the GEMM launch happens here. const size_t num_gemms = A.size(); - NVTE_CHECK(B.size() == num_gemms && D.size() == num_gemms, - "A, B, D must have matching lengths."); + NVTE_CHECK(B.size() == num_gemms && D.size() == num_gemms, "A, B, D must have matching lengths."); const bool have_bias = !bias.empty(); NVTE_CHECK(!have_bias || bias.size() == num_gemms, "bias must be empty or have length num_gemms."); @@ -702,11 +701,11 @@ void nvfp4_grouped_per_tensor_gemm(std::vector A, bool transa, } NVTE_SCOPED_GIL_RELEASE({ - nvte_nvfp4_grouped_per_tensor_gemm( - te_A_vector.data(), te_B_vector.data(), te_D_vector.data(), - have_bias ? 
te_bias_vector.data() : nullptr, static_cast(num_gemms), transa, transb, - accumulate, reinterpret_cast(alpha.data_ptr()), - at::cuda::getCurrentCUDAStream()); + nvte_nvfp4_grouped_per_tensor_gemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(), + have_bias ? te_bias_vector.data() : nullptr, + static_cast(num_gemms), transa, transb, accumulate, + reinterpret_cast(alpha.data_ptr()), + at::cuda::getCurrentCUDAStream()); }); } diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index e8249e5e13..2c21259e28 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -350,8 +350,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Grouped GEMM"); m.def("nvfp4_grouped_per_tensor_compute_alpha", &transformer_engine::pytorch::nvfp4_grouped_per_tensor_compute_alpha, - "Bench-only: precompute per-group alpha for per-tensor NVFP4 grouped GEMM", - py::arg("A"), py::arg("transa"), py::arg("B"), py::arg("transb")); + "Bench-only: precompute per-group alpha for per-tensor NVFP4 grouped GEMM", py::arg("A"), + py::arg("transa"), py::arg("B"), py::arg("transb")); m.def("nvfp4_grouped_per_tensor_gemm", &transformer_engine::pytorch::nvfp4_grouped_per_tensor_gemm, "Bench-only: single-launch per-tensor NVFP4 grouped GEMM with precomputed alpha",