[Common/PyTorch] Grouped-quantize kernels for 1D and 2D FP8 block-scaling#3135
denera wants to merge 2 commits into
Implements grouped-tensor quantize for the FP8 1D (1x128) and 2D (128x128)
block-scaling recipes. A single CUDA kernel launch walks 128x128 tiles
across every tensor in the group, with each CTA decoding its owning
tensor from the device-side GroupedTensor metadata.
Supported shape representations:
- SAME_BOTH_DIMS (all tensors identical)
- VARYING_FIRST_DIM (constant K, varying R - the common MoE topology)
Supported directions: rowwise-only, columnwise-only, and both.
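The tile-to-tensor decode described above can be sketched on the host as follows. This is a minimal illustration, not the kernel's code: the function names are hypothetical, and `offsets` is assumed to hold `num_tensors + 1` element offsets into the packed buffer (the last entry being the total element count), with every first dim a multiple of 128.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

constexpr std::size_t kTileDim = 128;

// SAME_BOTH_DIMS: every tensor owns the same number of tile rows,
// so the owner is a plain integer division.
inline std::size_t tensor_id_same_dims(std::size_t block_y,
                                       std::size_t common_first_dim_blocks) {
  assert(common_first_dim_blocks > 0);
  return block_y / common_first_dim_blocks;
}

// VARYING_FIRST_DIM: binary-search the element offsets for the tensor that
// contains the first element of tile row `block_y`.
// Assumption: offsets has num_tensors + 1 entries, offsets[0] == 0.
inline std::size_t tensor_id_varying(std::size_t block_y,
                                     const std::vector<std::int64_t>& offsets,
                                     std::size_t K) {
  const auto first_elem = static_cast<std::int64_t>(block_y * kTileDim * K);
  assert(!offsets.empty() && first_elem < offsets.back());
  // upper_bound finds the first offset strictly greater than first_elem;
  // the owning tensor is the entry just before it.
  const auto it = std::upper_bound(offsets.begin(), offsets.end(), first_elem);
  return static_cast<std::size_t>(it - offsets.begin()) - 1;
}
```

For first dims {128, 256, 384} and K = 128, tile rows 0, 1-2 and 3-5 map to tensors 0, 1 and 2 respectively. Using `upper_bound` rather than `lower_bound` means a zero-row tensor (two equal offsets) never claims a tile.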
These kernels are gated to Hopper (sm_90) at the host dispatcher because
the consumer cuBLAS FP8 block-scaling *grouped* GEMM is itself
Hopper-only (cuBLAS does not provide native FP8 block-scaling grouped
GEMM on Blackwell; the recommended quantization recipe on Blackwell is
MXFP8). The device-side kernel bodies are gated on __CUDA_ARCH__ >= 900
so the kernels compile and link as part of multi-arch builds, but the
host gate prevents launches on Blackwell.
Three kernels share the dispatcher in
group_quantize_blockwise_{1d,2d}:
| Kernel | Dispatched when | Threading | Smem |
|--------|-----------------|-----------|------|
| group_block_scaled_1d_rw_kernel | 1D RW-only | 8 threads/row x 32 row-warps x 4 iters; reads gmem directly into vec-16 registers | none |
| group_block_scaled_1d_tma_kernel | 1D CW or 1D BOTH | TMA bulk-load fills 32 KB input cache. BOTH runs RW pass first (8 t/row, vec-16) then CW pass (2 t/col, 64-row register stage); CW-only skips the RW pass. CW writes the transposed-FP8 tile to a 16.5 KB smem_T staging buffer, then drains to gmem. | 32 KB + 16.5 KB |
| group_block_scaled_2d_tma_kernel | 2D RW / CW / BOTH | TMA bulk-load fills 32 KB cache. Pass 1 stages 8 IVecs/thread in registers while computing the per-tile scalar amax. Pass 2 quantizes from registers, emits rowwise output, stages columnwise output to smem_T, then drains. | 32 KB + 16.5 KB |
The RW-only 1D path bypasses TMA because a streaming read has no reuse:
the smem round-trip and mbarrier overhead would just add latency.
The C++ test tests/cpp/operator/test_cast_float8blockwise_grouped.cu
exercises 72 configurations covering RW/CW/BOTH x 1D/2D x SAME/VARYING
shape representations against a per-tensor split-quantize reference.
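For orientation, the per-block amax that both recipes rely on can be written as a scalar reference. This is a sketch under assumptions, not the test's actual reference code: `block_amax` is a hypothetical helper, and the block shape is a parameter so that 1x128 (1D rowwise) and 128x128 (2D) are both instances of it.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Per-block absolute maximum of a row-major R x K matrix.
// block_rows x block_cols would be 1 x 128 for 1D rowwise, 128 x 128 for 2D.
// Returns a row-major [ceil(R / block_rows), ceil(K / block_cols)] array.
std::vector<float> block_amax(const std::vector<float>& x, std::size_t R,
                              std::size_t K, std::size_t block_rows,
                              std::size_t block_cols) {
  assert(x.size() == R * K && block_rows > 0 && block_cols > 0);
  const std::size_t by = (R + block_rows - 1) / block_rows;
  const std::size_t bx = (K + block_cols - 1) / block_cols;
  std::vector<float> amax(by * bx, 0.0f);
  for (std::size_t r = 0; r < R; ++r) {
    for (std::size_t c = 0; c < K; ++c) {
      float& a = amax[(r / block_rows) * bx + c / block_cols];
      a = std::max(a, std::fabs(x[r * K + c]));
    }
  }
  return amax;
}
```

A split-quantize reference would run something like this once per tensor; the grouped kernels compute the same quantity per tile in a single launch.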
Signed-off-by: Alp Dener <adener@nvidia.com>
| constexpr int kThreadsPerBlock = 256; | ||
| constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp; | ||
|
|
||
| // Align a dynamic-smem pointer to 128 bytes (TMA requirement). |
Could we reuse the existing align_smem_ptr_per_TMA_requirements() helper from transformer_engine/cast/core/common.h here?
| size_t total_row_blocks) { | ||
| using namespace transformer_engine::dispatch::mxfp8::swizzle; | ||
| const size_t num_tiles_X = | ||
| (total_row_blocks + GEMM_SWIZZLED_SCALE_TILE_DIM_X - 1) / GEMM_SWIZZLED_SCALE_TILE_DIM_X; |
We can also reuse the existing DIVUP() helper here (defined in transformer_engine/common/common.h).
|
|
||
| // ---- Tensor-lookup helpers ---------------------------------------------------- | ||
|
|
||
| // Map a global tile-row index to its owning tensor by binary-searching |
We can also reuse the existing get_current_tensor_id() helper defined in transformer_engine/cast/core/common.cuh
Greptile Summary
This PR adds grouped-tensor FP8 quantize kernels for the 1D (1×128) and 2D (128×128) block-scaling recipes, dispatching a single CUDA kernel launch across all tensors in the group so each CTA decodes its owning tensor from device-side metadata. It also lowers several
Confidence Score: 3/5
The new kernels are functionally correct for well-formed inputs, but the VARYING_FIRST_DIM path will silently produce incorrect output (skipping the last partial tile-row of any tensor) if a caller passes a per-tensor first dimension that is not a multiple of 128. SAME_BOTH_DIMS enforces this alignment explicitly; VARYING_FIRST_DIM does not.
The three new CUDA kernels, the new PTX intrinsic, and the plumbing through both the C++ and PyTorch dispatch layers all work correctly for the validated inputs. The VARYING_FIRST_DIM dispatcher skips the per-tensor first-dim alignment check that the SAME_BOTH_DIMS path enforces, so out-of-contract callers would silently receive truncated (partially un-quantized) tensors with no error.
See transformer_engine/common/cast/fp8_blockwise/group_quantize_fp8_blockwise.cuh, specifically the VARYING_FIRST_DIM branch of prepare_grouped_blockwise_launch (~line 655) and the 1D dispatcher error message (~line 747).
Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["nvte_group_quantize / group_quantize (PyTorch)"] --> B{scaling_mode?}
B -->|BLOCK_SCALING_1D| C["group_quantize_blockwise_1d()"]
B -->|BLOCK_SCALING_2D| D["group_quantize_blockwise_2d()"]
C --> E{use_rowwise only?}
E -->|Yes| F["group_block_scaled_1d_rw_kernel\n(no smem cache, vec-16 gmem reads)"]
E -->|CW or BOTH| G["group_block_scaled_1d_tma_kernel\n(TMA bulk-load → smem cache)\nRW pass + CW pass with smem_T transpose"]
D --> H["group_block_scaled_2d_tma_kernel\n(TMA bulk-load → smem cache)\nPass 1: amax in registers\nPass 2: quantize + smem_T drain"]
F --> L{kSameBothDims?}
G --> L
H --> L
L -->|Yes| M["tensor_id = block_y / common_first_dim_blocks"]
L -->|VARYING_FIRST_DIM| N["binary search on device offsets\ntensor_block_y_base_from_offsets"]
Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."
| } else { | ||
| info.common_first_dim_blocks = 0; | ||
| info.R_total = output->logical_shape.data[0]; | ||
| info.tensor_offsets_d = reinterpret_cast<const int64_t*>(output->tensor_offsets.dptr); | ||
| NVTE_CHECK(info.tensor_offsets_d != nullptr, | ||
| "VARYING_FIRST_DIM requires tensor_offsets to be set on the GroupedTensor."); | ||
| } | ||
| info.total_row_blocks = (info.R_total + kTileDim - 1) / kTileDim; | ||
| info.blocks_X = (info.K + kTileDim - 1) / kTileDim; | ||
| return info; | ||
| } |
VARYING_FIRST_DIM path silently requires 128-aligned per-tensor first dims
The SAME_BOTH_DIMS branch (line 651) enforces common_first_dim % kTileDim == 0, but the VARYING_FIRST_DIM branch has no equivalent check. The kernel's correctness depends entirely on this alignment: tensor_block_y_base_from_offsets divides element offsets by kTileDim * K using integer truncation, and tensor_row_blocks is derived the same way. A tensor with first_dim = 192 (not a multiple of 128) would produce tensor_row_blocks = 1 instead of 2, causing the second 64-row slice (rows 128–191) to be silently skipped by the in-kernel bounds guard and left un-quantized. The offsets are device-resident so host validation isn't straightforward, but a prominent NVTE_CHECK comment or a note in the function contract would prevent silent data loss from callers with unexpected shapes.
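The arithmetic behind the first_dim = 192 case can be checked in isolation. The sketch below uses hypothetical helper names and only models the integer math described in this comment; it is not the kernel's code.

```cpp
#include <cassert>
#include <cstddef>

constexpr std::size_t kTileDim = 128;

// Truncating division, as the comment describes for tensor_row_blocks.
constexpr std::size_t row_blocks_truncating(std::size_t first_dim) {
  return first_dim / kTileDim;
}

// Ceiling division: the number of tile rows actually needed to cover the tensor.
constexpr std::size_t row_blocks_ceil(std::size_t first_dim) {
  return (first_dim + kTileDim - 1) / kTileDim;
}

// The two agree only when first_dim is tile-aligned.
static_assert(row_blocks_truncating(192) == 1, "rows 128-191 get no tile row");
static_assert(row_blocks_ceil(192) == 2, "two tile rows are needed");
static_assert(row_blocks_truncating(256) == row_blocks_ceil(256), "aligned case");

// Hypothetical guard for callers that still hold the first dims on the host.
inline std::size_t row_blocks_checked(std::size_t first_dim) {
  assert(first_dim % kTileDim == 0 && "first dim must be a multiple of 128");
  return first_dim / kTileDim;
}
```

A guard like `row_blocks_checked` only helps where the per-tensor first dims are still available on the host; once only the device-resident offsets remain, the contract has to be documented instead.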
| NVTE_CHECK(sm >= 90 && sm < 100, | ||
| "Grouped FP8 block-scaling quantize is only supported on Hopper (SM90); " | ||
| "use MXFP8 on Blackwell (SM100) or newer. Got SM", | ||
| sm, "."); |
The error message in group_quantize_blockwise_1d says "SM90" while the identical check in group_quantize_blockwise_2d correctly says "SM90-SM99". The condition sm >= 90 && sm < 100 covers the full Hopper range, so the 1D message is misleading.
| NVTE_CHECK(sm >= 90 && sm < 100, | |
| "Grouped FP8 block-scaling quantize is only supported on Hopper (SM90); " | |
| "use MXFP8 on Blackwell (SM100) or newer. Got SM", | |
| sm, "."); | |
| NVTE_CHECK(sm >= 90 && sm < 100, | |
| "Grouped FP8 block-scaling quantize is only supported on Hopper (SM90-SM99); " | |
| "use MXFP8 on Blackwell (SM100) or newer. Got SM", | |
| sm, "."); |
| if (first_dims_d) cudaFree(first_dims_d); | ||
| } | ||
|
|
||
| struct TestConfig { | ||
| ShapeRep shape_rep; | ||
| BlockDim block_dim; | ||
| ScalingDir dir; | ||
| std::vector<size_t> first_dims; | ||
| size_t K; | ||
| }; | ||
|
|
||
| class GroupedFP8BlockwiseTestSuite : public ::testing::TestWithParam<TestConfig> {}; | ||
|
|
||
| TEST_P(GroupedFP8BlockwiseTestSuite, Test) { | ||
| const TestConfig& cfg = GetParam(); | ||
| perform_test<bf16, fp8e4m3>(cfg.shape_rep, cfg.block_dim, cfg.dir, cfg.first_dims, cfg.K, | ||
| /*force_pow_2_scales=*/true, /*epsilon=*/0.0f); | ||
| } | ||
|
|
||
| std::vector<TestConfig> make_configs() { | ||
| std::vector<TestConfig> configs; | ||
| std::vector<std::vector<size_t>> uniform = {{128, 128}, {256, 256, 256, 256}}; | ||
| std::vector<std::vector<size_t>> jagged = { | ||
| {128, 256, 384, 512}, {256, 128, 512, 384, 1024}}; | ||
| std::vector<size_t> Ks = {128, 256, 512}; | ||
| for (auto bd : {BlockDim::ONE_D, BlockDim::TWO_D}) { | ||
| for (auto dir : {ScalingDir::ROWWISE, ScalingDir::COLWISE, ScalingDir::BOTH}) { | ||
| for (size_t K : Ks) { | ||
| for (const auto& v : uniform) { | ||
| configs.push_back({ShapeRep::SAME_BOTH_DIMS, bd, dir, v, K}); | ||
| } |
Swizzled-scale path (with_gemm_swizzled_scales=true) is not exercised
The host dispatchers plumb output->with_gemm_swizzled_scales into both the 1D and 2D kernels (the kSwizzled template parameter), and the swizzled-scale indexing in swizzled_colwise_scale_idx is a separate non-trivial code path. Neither make_configs() nor any test fixture sets this flag, so the swizzled layout is never compared against a reference. Since cuBLAS FP8 block-scaling GEMM is the primary consumer of the swizzled layout, a bug there would be invisible until GEMM produces wrong results.
|
|
||
| // ---- TMA async load of the input tile ---- | ||
| if (leading_thread) { | ||
| ptx::mbarrier_init(&tma_mbar, 1); |
Since mbar resides in shared memory, a cross-proxy fence between the async and generic proxies needs to be issued here before __syncthreads() so that both the TMA engine and the threads observe mbar in the correct state. We can use ptx::fence_proxy_async_shared_cta() defined in transformer_engine/common/util/ptx.cuh.
| } | ||
|
|
||
| CType amax = compute_row_amax<IType, CType, kVec>(in_vec[it]); | ||
| amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, 1)); |
Could we reuse the existing amax warp-reduction helpers (warp_reduce_max() or reduce_max()) from transformer_engine/common/utils.cuh here?
| // ---- TMA async load of the input tile ---- | ||
| if (leading_thread) { | ||
| ptx::mbarrier_init(&tma_mbar, 1); | ||
| } |
Similar to the above:
| } | |
| ptx::fence_proxy_async_shared_cta(); | |
| } |
| amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, 1)); | ||
| amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, 2)); | ||
| amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, 4)); |
We can also reuse reduce_max() or warp_reduce_max() here.
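For readers unfamiliar with the pattern being replaced: the three `__shfl_xor_sync` steps with offsets 1, 2 and 4 form a butterfly reduction over aligned groups of 8 lanes, which matches the 8 threads/row layout. The host-side model below only illustrates that data flow; `butterfly_max` is a hypothetical name and nothing here runs on the GPU.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>

// Host model of: amax = fmaxf(amax, __shfl_xor_sync(mask, amax, offset))
// applied for offset = 1, 2, 4 across a 32-lane warp.
std::array<float, 32> butterfly_max(std::array<float, 32> lanes) {
  for (std::size_t offset = 1; offset <= 4; offset <<= 1) {
    // All lanes exchange simultaneously, so read from a snapshot.
    const std::array<float, 32> prev = lanes;
    for (std::size_t lane = 0; lane < 32; ++lane) {
      lanes[lane] = std::max(prev[lane], prev[lane ^ offset]);
    }
  }
  // Every lane now holds the max of its aligned 8-lane group.
  return lanes;
}
```

Because each step pairs lane `i` with lane `i ^ offset`, the result is broadcast to all 8 lanes of the group without a separate shuffle.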
|
|
||
| // ----- Host-side dispatchers -------------------------------------------------------------------- | ||
|
|
||
| inline size_t align_up_to(size_t x, size_t a) { return ((x + a - 1) / a) * a; } |
We can reuse DIVUP_TO_MULTIPLE() defined in transformer_engine/common/common.h.
| NVTE_CHECK(info.tensor_offsets_d != nullptr, | ||
| "VARYING_FIRST_DIM requires tensor_offsets to be set on the GroupedTensor."); | ||
| } | ||
| info.total_row_blocks = (info.R_total + kTileDim - 1) / kTileDim; |
| info.total_row_blocks = (info.R_total + kTileDim - 1) / kTileDim; | |
| info.total_row_blocks = DIVUP(info.R_total, kTileDim); |
| "VARYING_FIRST_DIM requires tensor_offsets to be set on the GroupedTensor."); | ||
| } | ||
| info.total_row_blocks = (info.R_total + kTileDim - 1) / kTileDim; | ||
| info.blocks_X = (info.K + kTileDim - 1) / kTileDim; |
| info.blocks_X = (info.K + kTileDim - 1) / kTileDim; | |
| info.blocks_X = DIVUP(info.K, kTileDim); |
| info.same_both_dims = same_both_dims; | ||
| info.num_tensors = output->num_tensors; | ||
| info.K = output->get_common_last_dim(); | ||
| NVTE_CHECK(info.K % 16 == 0, "Last dim must be multiple of 16 (FP8 alignment)."); |
If this is a TMA requirement, we can use the TMA_GMEM_ALIGNMENT constant defined in transformer_engine/common/common.h
| const float* noop_ptr = | ||
| (noop != nullptr) ? reinterpret_cast<const float*>(noop->data.dptr) : nullptr; | ||
|
|
||
| const size_t scale_stride_y = align_up_to(info.blocks_X, 4); |
| const size_t scale_stride_y = align_up_to(info.blocks_X, 4); | |
| const size_t scale_stride_y = DIVUP_TO_MULTIPLE(info.blocks_X, 4); |
| const size_t scale_stride_y = align_up_to(info.blocks_X, 4); | ||
| // CW scales are stored [blocks_X, align4(total_row_blocks)] -- transposed to | ||
| // match the physically-transposed columnwise data the TN cuBLAS GEMM consumes. | ||
| const size_t scale_t_stride_y = align_up_to(info.total_row_blocks, 4); |
| const size_t scale_t_stride_y = align_up_to(info.total_row_blocks, 4); | |
| const size_t scale_t_stride_y = DIVUP_TO_MULTIPLE(info.total_row_blocks, 4); |
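To make the two strides concrete, here is a sketch of how a tile's scale index could be derived from them. The CW layout [blocks_X, align4(total_row_blocks)] comes from the code comment above; the RW layout [total_row_blocks, align4(blocks_X)] is an assumption inferred from the variable name, and `ScaleLayout2D` and `round_up` are hypothetical names used only for illustration.

```cpp
#include <cassert>
#include <cstddef>

// Round x up to the next multiple of a (same result as align_up_to above).
constexpr std::size_t round_up(std::size_t x, std::size_t a) {
  return ((x + a - 1) / a) * a;
}

struct ScaleLayout2D {
  std::size_t blocks_X;          // tiles along K
  std::size_t total_row_blocks;  // tiles along R, summed over the group

  constexpr std::size_t stride_rw() const { return round_up(blocks_X, 4); }
  constexpr std::size_t stride_cw() const { return round_up(total_row_blocks, 4); }

  // Assumed RW layout: [total_row_blocks, align4(blocks_X)].
  constexpr std::size_t rw_index(std::size_t block_y, std::size_t block_x) const {
    return block_y * stride_rw() + block_x;
  }
  // CW layout per the comment: [blocks_X, align4(total_row_blocks)], transposed.
  constexpr std::size_t cw_index(std::size_t block_y, std::size_t block_x) const {
    return block_x * stride_cw() + block_y;
  }
};
```

With blocks_X = 3 and total_row_blocks = 5, the strides pad to 4 and 8, so tile (2, 1) lands at RW index 9 and CW index 10.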
| const float* noop_ptr = | ||
| (noop != nullptr) ? reinterpret_cast<const float*>(noop->data.dptr) : nullptr; | ||
|
|
||
| const size_t scale_stride_aligned_R = align_up_to(info.R_total, 4); |
| const size_t scale_stride_aligned_R = align_up_to(info.R_total, 4); | |
| const size_t scale_stride_aligned_R = DIVUP_TO_MULTIPLE(info.R_total, 4); |
| (noop != nullptr) ? reinterpret_cast<const float*>(noop->data.dptr) : nullptr; | ||
|
|
||
| const size_t scale_stride_aligned_R = align_up_to(info.R_total, 4); | ||
| const size_t scale_t_stride_aligned_K = align_up_to(info.K, 4); |
| const size_t scale_t_stride_aligned_K = align_up_to(info.K, 4); | |
| const size_t scale_t_stride_aligned_K = DIVUP_TO_MULTIPLE(info.K, 4); |
| // ---- TMA async load of the input tile ---- | ||
| if (leading_thread) { | ||
| ptx::mbarrier_init(&tma_mbar, 1); | ||
| } |
| } | |
| ptx::fence_proxy_async_shared_cta(); | |
| } |
Description
Implements grouped-tensor quantize for the FP8 1D (1x128) and 2D (128x128) block-scaling recipes in row-wise (RW), column-wise (CW) and BOTH quantization directions. A single CUDA kernel launch walks 128x128 tiles across every tensor in the group, with each CTA decoding its owning tensor from the device-side GroupedTensor metadata with (N, R, K) shapes. Supports SAME_BOTH_DIMS (all tensors identical) and VARYING_FIRST_DIM (constant K, varying R) shape representations.
Three kernels share the dispatcher in group_quantize_blockwise_{1d,2d}:
- group_block_scaled_1d_rw_kernel: RW-only dispatch; 8 threads/row, reads global memory directly into vec-16 registers; bypasses TMA because the shared memory round trip and ptx::mbarrier synchronization buy nothing without the reuse that only the CW path provides.
- group_block_scaled_1d_tma_kernel: CW-only and BOTH dispatch; TMA bulk-load fills the shared memory input cache. BOTH runs the RW pass first (8 threads/row, vec-16 read from shared memory), then the CW pass (2 threads/column, 64-row register stage); CW-only skips the RW pass. The CW path writes the transposed-FP8 tile to a shared memory transpose staging buffer, then drains to global memory.
- group_block_scaled_2d_tma_kernel: RW-only, CW-only and BOTH dispatch; TMA bulk-load fills the shared memory input cache. Pass 1 stages 8 IVecs/thread in registers while computing the per-tile scalar amax. Pass 2 quantizes from registers, emits row-wise output, stages column-wise output to the shared memory transpose staging buffer, then drains to global memory.
Kernels are gated to Hopper (sm_90) at the host dispatcher (cuBLASLt grouped GEMM supports FP8 block-scaling only on Hopper).
PR includes PyTorch integration.
JAX integration is intentionally left out-of-scope and deferred to a follow-up PR because it requires non-trivial new scaffolding on the framework side.
Resolves #2525
Performance
Table below measures performance on H200 with a sweep of grouped tensors in (N, M, K) shapes with:
The shapes are split into two buckets:
Reported kernel times and throughput ratios are bucket medians.
Speedup is measured relative to the split-quantized fallback that loops over the grouped tensor and sequentially quantizes each one.
% of "mono" throughput is measured relative to the throughput of a single non-grouped FP8 block-scaling quantize kernel invoked on the equivalent monolithic (NxM, K) tensor, in which the number of experts is folded into the number of tokens per expert.
Notes
Known Sub-Optimalities
- 1D CW has bank conflicts on ~35% of load wavefronts (reading from the shared memory input cache).
  - CU_TENSOR_MAP_SWIZZLE_128B has the right pattern but caps FP16/BF16 at 64 elements; it does not fit the 128-element tile for FP8 block-scaling without doubling per-tile launch overhead (quadrupling for FP32).
- 1D BOTH reads the shared memory input cache twice.
- 2D CW/BOTH has bank conflicts on ~16% of store wavefronts (when writing to the shared memory transpose buffer).
- No TMA-store.