From 02c87d6bb88dd3e2d47883aaba3640ff09323dff Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Wed, 17 Jun 2026 08:27:11 +0000 Subject: [PATCH 1/2] Add grouped FP8 block-scaling quantize kernels Implements grouped-tensor quantize for the FP8 1D (1x128) and 2D (128x128) block-scaling recipes. A single CUDA kernel launch walks 128x128 tiles across every tensor in the group, with each CTA decoding its owning tensor from the device-side GroupedTensor metadata. Supported shape representations: - SAME_BOTH_DIMS (all tensors identical) - VARYING_FIRST_DIM (constant K, varying R - the common MoE topology) Supported directions: rowwise-only, columnwise-only, and both. These kernels are gated to Hopper (sm_90) at the host dispatcher because the consumer cuBLAS FP8 block-scaling *grouped* GEMM is itself Hopper-only (cuBLAS does not provide native FP8 block-scaling grouped GEMM on Blackwell; the recommended quantization recipe on Blackwell is MXFP8). The device-side kernel bodies are gated on __CUDA_ARCH__ >= 900 so the kernels compile and link as part of multi-arch builds, but the host gate prevents launches on Blackwell. Three kernels share the dispatcher in group_quantize_blockwise_{1d,2d}: | Kernel | Dispatched when | Threading | Smem | |--------|-----------------|-----------|------| | group_block_scaled_1d_rw_kernel | 1D RW-only | 8 threads/row x 32 row-warps x 4 iters; reads gmem directly into vec-16 registers | none | | group_block_scaled_1d_tma_kernel | 1D CW or 1D BOTH | TMA bulk-load fills 32 KB input cache. BOTH runs RW pass first (8 t/row, vec-16) then CW pass (2 t/col, 64-row register stage); CW-only skips the RW pass. CW writes the transposed-FP8 tile to a 16.5 KB smem_T staging buffer, then drains to gmem. | 32 KB + 16.5 KB | | group_block_scaled_2d_tma_kernel | 2D RW / CW / BOTH | TMA bulk-load fills 32 KB cache. Pass 1 stages 8 IVecs/thread in registers while computing the per-tile scalar amax. Pass 2 quantizes from registers, emits rowwise output, stages columnwise output to smem_T, then drains. | 32 KB + 16.5 KB | The RW-only 1D path bypasses TMA because a streaming read has no reuse - the smem round-trip and mbarrier overhead would just add latency. The C++ test tests/cpp/operator/test_cast_float8blockwise_grouped.cu exercises 72 configurations covering RW/CW/BOTH x 1D/2D x SAME/VARYING shape representations against a per-tensor split-quantize reference. Signed-off-by: Alp Dener --- tests/cpp/operator/CMakeLists.txt | 1 + .../test_cast_float8blockwise_grouped.cu | 380 ++++++++ .../common/cast/dispatch/quantize.cuh | 31 + .../group_quantize_fp8_blockwise.cuh | 846 ++++++++++++++++++ transformer_engine/common/util/ptx.cuh | 48 +- .../pytorch/csrc/extensions/cast.cpp | 17 +- 6 files changed, 1307 insertions(+), 16 deletions(-) create mode 100644 tests/cpp/operator/test_cast_float8blockwise_grouped.cu create mode 100644 transformer_engine/common/cast/fp8_blockwise/group_quantize_fp8_blockwise.cuh diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 9b67c09f34..918e842700 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -14,6 +14,7 @@ add_executable(test_operator test_cast_mxfp8_grouped.cu test_cast_nvfp4_transpose.cu test_cast_float8blockwise.cu + test_cast_float8blockwise_grouped.cu test_dequantize_mxfp8.cu test_dequantize_mxfp8_grouped.cu test_dequantize_nvfp4.cu diff --git a/tests/cpp/operator/test_cast_float8blockwise_grouped.cu b/tests/cpp/operator/test_cast_float8blockwise_grouped.cu new file mode 100644 index 0000000000..d701a0bd13 --- /dev/null +++ b/tests/cpp/operator/test_cast_float8blockwise_grouped.cu @@ -0,0 +1,380 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "../test_common.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +enum class ShapeRep { SAME_BOTH_DIMS = 0, VARYING_FIRST_DIM = 1 }; +enum class ScalingDir { ROWWISE = 0, COLWISE = 1, BOTH = 2 }; +enum class BlockDim { ONE_D = 1, TWO_D = 2 }; + +constexpr size_t kBlock = 128; + +inline size_t align4(size_t x) { return ((x + 3) / 4) * 4; } + +// Configure split-quantize reference: call non-grouped nvte_quantize_v2 on each tensor slice. +// Returns flat host buffers for per-tensor outputs and scales (in their per-tensor natural +// layout) so the test can index them and compare element-wise against the grouped layout. +struct PerTensorRef { + std::vector> output; // per tensor, FP8 raw bytes (R_t * K) + std::vector> output_t; // per tensor, FP8 raw bytes (K * R_t) + std::vector> scale_inv; // per tensor, layout per non-grouped impl + std::vector> scale_inv_t; // per tensor, layout per non-grouped impl +}; + +template +void perform_test(ShapeRep shape_rep, BlockDim block_dim, ScalingDir dir, + const std::vector& first_dims_h, size_t K, + bool force_pow_2_scales, float epsilon) { + if (getDeviceComputeCapability() < hopperComputeCapability) { + GTEST_SKIP(); + } + + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + const size_t num_tensors = first_dims_h.size(); + size_t R_total = 0; + for (size_t m : first_dims_h) { + ASSERT_EQ(m % kBlock, 0u) << "Per-tensor first dim must be multiple of 128"; + R_total += m; + } + ASSERT_EQ(K % 16u, 0u); + + // Host data + std::mt19937 gen(0xC0FFEEu); + std::uniform_real_distribution dist(-2.0f, 1.0f); + std::vector input_h(R_total * K); + for (auto& v : input_h) v = static_cast(dist(gen)); + + // Tensor offsets (element offsets) + std::vector offsets_h(num_tensors + 1, 0); + for (size_t t = 0; t < num_tensors; ++t) { + offsets_h[t + 1] = offsets_h[t] + static_cast(first_dims_h[t] * K); + } + std::vector first_dims_i64(num_tensors); + for (size_t t = 0; t < num_tensors; ++t) first_dims_i64[t] = static_cast(first_dims_h[t]); + + const bool use_rowwise = (dir == ScalingDir::ROWWISE || dir == ScalingDir::BOTH); + const bool use_colwise = (dir == ScalingDir::COLWISE || dir == ScalingDir::BOTH); + + const NVTEScalingMode mode = + (block_dim == BlockDim::ONE_D) ? NVTE_BLOCK_SCALING_1D : NVTE_BLOCK_SCALING_2D; + + // Allocate grouped device buffers. + InputType* input_d = nullptr; + OutputType* output_d = nullptr; + OutputType* output_t_d = nullptr; + float* scale_inv_d = nullptr; + float* scale_inv_t_d = nullptr; + int64_t* offsets_d = nullptr; + int64_t* first_dims_d = nullptr; + + const size_t total_row_blocks = (R_total + kBlock - 1) / kBlock; + const size_t blocks_X = (K + kBlock - 1) / kBlock; + + size_t scale_inv_elems = 0; + size_t scale_inv_t_elems = 0; + std::vector scale_inv_shape, scale_inv_t_shape; + if (block_dim == BlockDim::ONE_D) { + // Rowwise: (blocks_X, align4(R_total)) + scale_inv_shape = {blocks_X, align4(R_total)}; + scale_inv_t_shape = {total_row_blocks, align4(K)}; + } else { + scale_inv_shape = {total_row_blocks, align4(blocks_X)}; + scale_inv_t_shape = {blocks_X, align4(total_row_blocks)}; + } + scale_inv_elems = scale_inv_shape[0] * scale_inv_shape[1]; + scale_inv_t_elems = scale_inv_t_shape[0] * scale_inv_t_shape[1]; + + const size_t input_bytes = R_total * K * sizeof(InputType); + const size_t output_bytes = R_total * K * sizeof(OutputType); + + cudaMalloc(&input_d, input_bytes); + cudaMemcpy(input_d, input_h.data(), input_bytes, cudaMemcpyHostToDevice); + cudaMalloc(&offsets_d, (num_tensors + 1) * sizeof(int64_t)); + cudaMemcpy(offsets_d, offsets_h.data(), (num_tensors + 1) * sizeof(int64_t), + cudaMemcpyHostToDevice); + if (shape_rep == ShapeRep::VARYING_FIRST_DIM) { + cudaMalloc(&first_dims_d, num_tensors * sizeof(int64_t)); + cudaMemcpy(first_dims_d, first_dims_i64.data(), num_tensors * sizeof(int64_t), + cudaMemcpyHostToDevice); + } + if (use_rowwise) { + cudaMalloc(&output_d, output_bytes); + cudaMemset(output_d, 0, output_bytes); + cudaMalloc(&scale_inv_d, scale_inv_elems * sizeof(float)); + cudaMemset(scale_inv_d, 0, scale_inv_elems * sizeof(float)); + } + if (use_colwise) { + cudaMalloc(&output_t_d, output_bytes); + cudaMemset(output_t_d, 0, output_bytes); + cudaMalloc(&scale_inv_t_d, scale_inv_t_elems * sizeof(float)); + cudaMemset(scale_inv_t_d, 0, scale_inv_t_elems * sizeof(float)); + } + + // Build grouped tensors. + std::vector logical_shape_vec = {R_total, K}; + NVTEShape logical_shape = nvte_make_shape(logical_shape_vec.data(), logical_shape_vec.size()); + + NVTEGroupedTensor in_gt = nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING, num_tensors, + logical_shape); + NVTEGroupedTensor out_gt = nvte_create_grouped_tensor(mode, num_tensors, logical_shape); + + NVTEBasicTensor in_data = {input_d, static_cast(itype), logical_shape}; + nvte_set_grouped_tensor_param(in_gt, kNVTEGroupedRowwiseData, &in_data, sizeof(in_data)); + + NVTEShape offsets_shape; + offsets_shape.ndim = 1; + offsets_shape.data[0] = num_tensors + 1; + NVTEBasicTensor offsets_bt = {offsets_d, kNVTEInt64, offsets_shape}; + if (shape_rep == ShapeRep::VARYING_FIRST_DIM) { + NVTEShape first_dims_shape; + first_dims_shape.ndim = 1; + first_dims_shape.data[0] = num_tensors; + NVTEBasicTensor first_dims_bt = {first_dims_d, kNVTEInt64, first_dims_shape}; + nvte_set_grouped_tensor_param(in_gt, kNVTEGroupedFirstDims, &first_dims_bt, + sizeof(first_dims_bt)); + nvte_set_grouped_tensor_param(out_gt, kNVTEGroupedFirstDims, &first_dims_bt, + sizeof(first_dims_bt)); + nvte_set_grouped_tensor_param(in_gt, kNVTEGroupedTensorOffsets, &offsets_bt, + sizeof(offsets_bt)); + nvte_set_grouped_tensor_param(out_gt, kNVTEGroupedTensorOffsets, &offsets_bt, + sizeof(offsets_bt)); + } + + if (use_rowwise) { + NVTEBasicTensor out_data = {output_d, static_cast(otype), logical_shape}; + NVTEShape scale_inv_shape_nv = nvte_make_shape(scale_inv_shape.data(), scale_inv_shape.size()); + NVTEBasicTensor scale_bt = {scale_inv_d, kNVTEFloat32, scale_inv_shape_nv}; + nvte_set_grouped_tensor_param(out_gt, kNVTEGroupedRowwiseData, &out_data, sizeof(out_data)); + nvte_set_grouped_tensor_param(out_gt, kNVTEGroupedRowwiseScaleInv, &scale_bt, sizeof(scale_bt)); + } + if (use_colwise) { + NVTEBasicTensor out_t_data = {output_t_d, static_cast(otype), logical_shape}; + NVTEShape scale_inv_t_shape_nv = nvte_make_shape(scale_inv_t_shape.data(), + scale_inv_t_shape.size()); + NVTEBasicTensor scale_t_bt = {scale_inv_t_d, kNVTEFloat32, scale_inv_t_shape_nv}; + nvte_set_grouped_tensor_param(out_gt, kNVTEGroupedColumnwiseData, &out_t_data, + sizeof(out_t_data)); + nvte_set_grouped_tensor_param(out_gt, kNVTEGroupedColumnwiseScaleInv, &scale_t_bt, + sizeof(scale_t_bt)); + } + + // Run grouped quantize. + QuantizationConfigWrapper quant_config; + quant_config.set_force_pow_2_scales(force_pow_2_scales); + quant_config.set_amax_epsilon(epsilon); + nvte_group_quantize(in_gt, out_gt, quant_config, 0); + cudaDeviceSynchronize(); + ASSERT_EQ(cudaGetLastError(), cudaSuccess); + + // Pull grouped outputs back to host. + std::vector output_h(use_rowwise ? R_total * K : 0); + std::vector output_t_h(use_colwise ? R_total * K : 0); + std::vector scale_inv_h(use_rowwise ? scale_inv_elems : 0); + std::vector scale_inv_t_h(use_colwise ? scale_inv_t_elems : 0); + if (use_rowwise) { + cudaMemcpy(output_h.data(), output_d, R_total * K, cudaMemcpyDeviceToHost); + cudaMemcpy(scale_inv_h.data(), scale_inv_d, scale_inv_elems * sizeof(float), + cudaMemcpyDeviceToHost); + } + if (use_colwise) { + cudaMemcpy(output_t_h.data(), output_t_d, R_total * K, cudaMemcpyDeviceToHost); + cudaMemcpy(scale_inv_t_h.data(), scale_inv_t_d, scale_inv_t_elems * sizeof(float), + cudaMemcpyDeviceToHost); + } + + // Run split-quantize reference per tensor and compare element-wise. + for (size_t t = 0; t < num_tensors; ++t) { + const size_t M = first_dims_h[t]; + const size_t row_offset = static_cast(offsets_h[t]) / K; + + std::vector tshape = {M, K}; + Tensor ref_in("ref_in_" + std::to_string(t), tshape, itype); + // The non-grouped 2D kernel requires rowwise output to be allocated even when only colwise + // data is consumed. We always allocate both and compare only what the grouped kernel produced. + const bool ref_rowwise = (block_dim == BlockDim::TWO_D) ? true : use_rowwise; + const bool ref_colwise = use_colwise; + Tensor ref_out("ref_out_" + std::to_string(t), tshape, otype, ref_rowwise, ref_colwise, mode); + + // Copy this tensor's input slice into ref_in. + { + auto* dst = ref_in.rowwise_dptr(); + const InputType* src = reinterpret_cast(input_d) + row_offset * K; + cudaMemcpy(dst, src, M * K * sizeof(InputType), cudaMemcpyDeviceToDevice); + } + + QuantizationConfigWrapper qc; + qc.set_force_pow_2_scales(force_pow_2_scales); + qc.set_amax_epsilon(epsilon); + nvte_quantize_v2(ref_in.data(), ref_out.data(), qc, 0); + cudaDeviceSynchronize(); + ASSERT_EQ(cudaGetLastError(), cudaSuccess); + ref_out.to_cpu(); // sync output and scale_inv buffers from GPU to CPU + + // Compare data. + if (use_rowwise) { + const OutputType* ref_data = ref_out.rowwise_cpu_dptr(); + for (size_t r = 0; r < M; ++r) { + for (size_t c = 0; c < K; ++c) { + const uint8_t got = output_h[(row_offset + r) * K + c]; + const uint8_t exp = reinterpret_cast(ref_data)[r * K + c]; + ASSERT_EQ(got, exp) << "rowwise data mismatch t=" << t << " r=" << r << " c=" << c; + } + } + } + if (use_colwise) { + const OutputType* ref_data_t = ref_out.columnwise_cpu_dptr(); + for (size_t c = 0; c < K; ++c) { + for (size_t r = 0; r < M; ++r) { + const uint8_t got = output_t_h[c * R_total + (row_offset + r)]; + const uint8_t exp = reinterpret_cast(ref_data_t)[c * M + r]; + ASSERT_EQ(got, exp) << "colwise data mismatch t=" << t << " c=" << c << " r=" << r; + } + } + } + + // Compare scales. + if (block_dim == BlockDim::ONE_D) { + const size_t M_pad = align4(M); + const size_t R_pad = align4(R_total); + const size_t K_pad = align4(K); + const size_t blocks_y_per_tensor = M / kBlock; + const size_t row_block_offset_t = row_offset / kBlock; + if (use_rowwise) { + const float* ref_sc = ref_out.rowwise_cpu_scale_inv_ptr(); + for (size_t bx = 0; bx < blocks_X; ++bx) { + for (size_t r = 0; r < M; ++r) { + const float got = scale_inv_h[bx * R_pad + (row_offset + r)]; + const float exp = ref_sc[bx * M_pad + r]; + ASSERT_EQ(got, exp) << "1D rowwise scale mismatch t=" << t << " bx=" << bx + << " r=" << r; + } + } + } + if (use_colwise) { + const float* ref_sct = ref_out.columnwise_cpu_scale_inv_ptr(); + for (size_t by = 0; by < blocks_y_per_tensor; ++by) { + for (size_t c = 0; c < K; ++c) { + const float got = scale_inv_t_h[(row_block_offset_t + by) * K_pad + c]; + const float exp = ref_sct[by * K_pad + c]; + ASSERT_EQ(got, exp) << "1D colwise scale mismatch t=" << t << " by=" << by + << " c=" << c; + } + } + } + } else { + // 2D: rowwise shape (blocks_y_total, align4(blocks_X)); per-tensor shape (M/128, + // align4(blocks_X)). Per-tensor block-row offset = M_block_off. + const size_t blocks_y_per_tensor = M / kBlock; + const size_t row_block_offset_t = row_offset / kBlock; + const size_t bx_pad = align4(blocks_X); + const size_t by_pad_total = align4(total_row_blocks); + const size_t by_pad_t = align4(blocks_y_per_tensor); + if (use_rowwise) { + const float* ref_sc = ref_out.rowwise_cpu_scale_inv_ptr(); + for (size_t by = 0; by < blocks_y_per_tensor; ++by) { + for (size_t bx = 0; bx < blocks_X; ++bx) { + const float got = scale_inv_h[(row_block_offset_t + by) * bx_pad + bx]; + const float exp = ref_sc[by * bx_pad + bx]; + ASSERT_EQ(got, exp) << "2D rowwise scale mismatch t=" << t << " by=" << by + << " bx=" << bx; + } + } + } + if (use_colwise) { + const float* ref_sct = ref_out.columnwise_cpu_scale_inv_ptr(); + for (size_t bx = 0; bx < blocks_X; ++bx) { + for (size_t by = 0; by < blocks_y_per_tensor; ++by) { + const float got = scale_inv_t_h[bx * by_pad_total + (row_block_offset_t + by)]; + const float exp = ref_sct[bx * by_pad_t + by]; + ASSERT_EQ(got, exp) << "2D colwise scale mismatch t=" << t << " bx=" << bx + << " by=" << by; + } + } + } + } + } + + nvte_destroy_grouped_tensor(in_gt); + nvte_destroy_grouped_tensor(out_gt); + cudaFree(input_d); + if (output_d) cudaFree(output_d); + if (output_t_d) cudaFree(output_t_d); + if (scale_inv_d) cudaFree(scale_inv_d); + if (scale_inv_t_d) cudaFree(scale_inv_t_d); + cudaFree(offsets_d); + if (first_dims_d) cudaFree(first_dims_d); +} + +struct TestConfig { + ShapeRep shape_rep; + BlockDim block_dim; + ScalingDir dir; + std::vector first_dims; + size_t K; +}; + +class GroupedFP8BlockwiseTestSuite : public ::testing::TestWithParam {}; + +TEST_P(GroupedFP8BlockwiseTestSuite, Test) { + const TestConfig& cfg = GetParam(); + perform_test(cfg.shape_rep, cfg.block_dim, cfg.dir, cfg.first_dims, cfg.K, + /*force_pow_2_scales=*/true, /*epsilon=*/0.0f); +} + +std::vector make_configs() { + std::vector configs; + std::vector> uniform = {{128, 128}, {256, 256, 256, 256}}; + std::vector> jagged = { + {128, 256, 384, 512}, {256, 128, 512, 384, 1024}}; + std::vector Ks = {128, 256, 512}; + for (auto bd : {BlockDim::ONE_D, BlockDim::TWO_D}) { + for (auto dir : {ScalingDir::ROWWISE, ScalingDir::COLWISE, ScalingDir::BOTH}) { + for (size_t K : Ks) { + for (const auto& v : uniform) { + configs.push_back({ShapeRep::SAME_BOTH_DIMS, bd, dir, v, K}); + } + for (const auto& v : jagged) { + configs.push_back({ShapeRep::VARYING_FIRST_DIM, bd, dir, v, K}); + } + } + } + } + return configs; +} + +std::string make_name(const ::testing::TestParamInfo& info) { + const auto& c = info.param; + std::string s = (c.shape_rep == ShapeRep::SAME_BOTH_DIMS ? "SAME" : "VARYFIRST"); + s += "_BD" + std::to_string(static_cast(c.block_dim)); + s += (c.dir == ScalingDir::ROWWISE ? "_RW" + : c.dir == ScalingDir::COLWISE ? "_CW" : "_BOTH"); + s += "_K" + std::to_string(c.K) + "_N" + std::to_string(c.first_dims.size()); + s += "_M"; + for (size_t m : c.first_dims) s += "_" + std::to_string(m); + return s; +} + +INSTANTIATE_TEST_SUITE_P(GroupedFP8Blockwise, GroupedFP8BlockwiseTestSuite, + ::testing::ValuesIn(make_configs()), make_name); + +} // namespace diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 6c71285cd4..23f5708a64 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -18,6 +18,7 @@ #include "../../util/vectorized_pointwise.h" #include "../core/common.cuh" #include "../fp8/quantize_fp8.cuh" +#include "../fp8_blockwise/group_quantize_fp8_blockwise.cuh" #include "../mxfp8/group_quantize_mxfp8.cuh" #include "../mxfp8/quantize_mxfp8.cuh" #include "../nvfp4/group_quantize_transpose_nvfp4.cuh" @@ -466,6 +467,20 @@ void group_quantize_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTensor workspace_tensor, &quant_config_cpp, stream); break; } + case NVTE_BLOCK_SCALING_1D: { + NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for grouped NVTE_BLOCK_SCALING_1D."); + fp8_blockwise::group_quantize_blockwise_1d(input_tensor, output_tensor, noop_tensor, + quant_config_cpp.amax_epsilon, + quant_config_cpp.force_pow_2_scales, stream); + break; + } + case NVTE_BLOCK_SCALING_2D: { + NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for grouped NVTE_BLOCK_SCALING_2D."); + fp8_blockwise::group_quantize_blockwise_2d(input_tensor, output_tensor, noop_tensor, + quant_config_cpp.amax_epsilon, + quant_config_cpp.force_pow_2_scales, stream); + break; + } default: NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + "."); } @@ -507,6 +522,22 @@ void group_quantize_bwd_helper(const NVTEGroupedTensor grad, const NVTEGroupedTe &quant_config_cpp, stream); break; } + case NVTE_BLOCK_SCALING_1D: { + NVTE_CHECK((!IS_DBIAS && !IS_DACT), + "IS_DBIAS and IS_DACT are not implemented for grouped NVTE_BLOCK_SCALING_1D."); + fp8_blockwise::group_quantize_blockwise_1d(grad_tensor, output_tensor, noop_tensor, + quant_config_cpp.amax_epsilon, + quant_config_cpp.force_pow_2_scales, stream); + break; + } + case NVTE_BLOCK_SCALING_2D: { + NVTE_CHECK((!IS_DBIAS && !IS_DACT), + "IS_DBIAS and IS_DACT are not implemented for grouped NVTE_BLOCK_SCALING_2D."); + fp8_blockwise::group_quantize_blockwise_2d(grad_tensor, output_tensor, noop_tensor, + quant_config_cpp.amax_epsilon, + quant_config_cpp.force_pow_2_scales, stream); + break; + } default: NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + "."); } diff --git a/transformer_engine/common/cast/fp8_blockwise/group_quantize_fp8_blockwise.cuh b/transformer_engine/common/cast/fp8_blockwise/group_quantize_fp8_blockwise.cuh new file mode 100644 index 0000000000..b7b49ebd9a --- /dev/null +++ b/transformer_engine/common/cast/fp8_blockwise/group_quantize_fp8_blockwise.cuh @@ -0,0 +1,846 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file group_quantize_fp8_blockwise.cuh + * \brief CUDA kernels to quantize grouped tensors with FP8 1D and 2D + * block scaling. A single launch walks 128x128 tiles across every tensor + * in the group, with each CTA decoding its owning tensor from the device-side + * GroupedTensor metadata. Supports SAME_BOTH_DIMS and VARYING_FIRST_DIM. + * With with_gemm_swizzled_scales, colwise scales are written in the + * MXFP8-compatible GEMM-swizzled layout for cuBLAS FP8 block-scaling GEMM. + */ + +#ifndef TRANSFORMER_ENGINE_GROUP_QUANTIZE_FP8_BLOCKWISE_CUH_ +#define TRANSFORMER_ENGINE_GROUP_QUANTIZE_FP8_BLOCKWISE_CUH_ + +#include +#include +#include +#include + +#include + +#include "../../common.h" +#include "../../recipe/recipe_common.cuh" +#include "../../transpose/cast_transpose.h" +#include "../../util/cuda_runtime.h" +#include "../../util/ptx.cuh" +#include "../../utils.cuh" +#include "../mxfp8/swizzle.cuh" + +namespace transformer_engine { +namespace dispatch { +namespace fp8_blockwise { + +using transformer_engine::detail::FP8BlockwiseColumnwiseOption; +using transformer_engine::detail::FP8BlockwiseRowwiseOption; + +constexpr int kTileDim = 128; +constexpr int kThreadsPerWarp = 32; +constexpr int kThreadsPerBlock = 256; +constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp; + +// Align a dynamic-smem pointer to 128 bytes (TMA requirement). +__device__ __forceinline__ unsigned char* align_smem_128(unsigned char* p) { + return reinterpret_cast((reinterpret_cast(p) + 127ULL) & ~127ULL); +} + +// ---- Swizzled scale index helper ----------------------------------------------- + +// Computes the gemm-swizzled colwise scale index for FP8 block-scaling. +// Follows the same convention as MXFP8 gemm_swizzled_scale_idx but applied to +// FP32 1D/2D block-scaling colwise scales: +// i = column index (per-column scales for 1D; tile_x for 2D) +// j = row-block index (tile_y_global) +// num_tiles_X = DIVUP(total_row_blocks, TILE_DIM_X=4) +__device__ __forceinline__ size_t swizzled_colwise_scale_idx(size_t i, size_t j, + size_t total_row_blocks) { + using namespace transformer_engine::dispatch::mxfp8::swizzle; + const size_t num_tiles_X = (total_row_blocks + GEMM_SWIZZLED_SCALE_TILE_DIM_X - 1) / + GEMM_SWIZZLED_SCALE_TILE_DIM_X; + return gemm_swizzled_scale_idx(i, j, num_tiles_X); +} + +// ---- Tensor-lookup helpers ---------------------------------------------------- + +// Map a global tile-row index to its owning tensor by binary-searching +// `tensor_offsets_ptr` (element offsets) directly on device. +template +__device__ __forceinline__ size_t find_tensor_id_by_block_y( + const size_t block_y_global, const size_t num_tensors, + const size_t common_first_dim_blocks, const size_t tile_row_stride, + const int64_t* __restrict__ tensor_offsets_ptr) { + if constexpr (kSameBothDims) { + return block_y_global / common_first_dim_blocks; + } + const size_t target_elt = block_y_global * tile_row_stride; + size_t lo = 1; + size_t hi = num_tensors; + while (lo < hi) { + const size_t mid = lo + (hi - lo) / 2; + if (static_cast(tensor_offsets_ptr[mid]) <= target_elt) { + lo = mid + 1; + } else { + hi = mid; + } + } + return lo - 1; +} + +// Per-tensor block-y base for VARYING_FIRST_DIM (in 128-row block units). +__device__ __forceinline__ size_t tensor_block_y_base_from_offsets( + const size_t tensor_id, const int64_t* __restrict__ tensor_offsets_ptr, + const size_t tile_row_stride) { + return static_cast(tensor_offsets_ptr[tensor_id]) / tile_row_stride; +} + +// Per-vector amax. Uses bf16x2 `max.xorsign.abs` on sm_89+; FP32 fallback otherwise. +template +__device__ __forceinline__ CType compute_row_amax(const Vec& v) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 890) + if constexpr (std::is_same_v) { + static_assert(kVec % 2 == 0, "kVec must be even for packed bf16x2 amax"); + const ptx::bf16x2* pairs = reinterpret_cast(&v.data.elt[0]); + ptx::bf16x2 amax_x2{static_cast(0.f), static_cast(0.f)}; +#pragma unroll + for (int p = 0; p < kVec / 2; ++p) { + ptx::abs_max_2x(amax_x2, amax_x2, pairs[p]); + } + return static_cast(__hmax(__habs(amax_x2.x), __habs(amax_x2.y))); + } +#endif + CType amax = 0.f; +#pragma unroll + for (int e = 0; e < kVec; ++e) { + amax = fmaxf(amax, fabsf(static_cast(v.data.elt[e]))); + } + return amax; +} + +// Per-vector multiply-and-quantize via fp32 intermediates. +template +__device__ __forceinline__ void quantize_row_vec(Vec& out, + const Vec& in, CType scale) { +#pragma unroll + for (int e = 0; e < kVec; ++e) { + out.data.elt[e] = static_cast(static_cast(in.data.elt[e]) * scale); + } +} + +// Bank-conflict swizzle delta for the 2D smem_T staging buffer. delta carries +// bits 2..5 only so each 4-byte sub-chunk is preserved. +__device__ __forceinline__ int smem_t_swz_delta(int smem_t_row) { + return ((smem_t_row >> 3) & 0xf) << 2; +} + +// Drain smem_T to gmem for the 1D CW path: 4 cols/warp, 8 lanes/col, each lane +// stores a 16-row chunk so the 8 lanes of a col emit one 128 B gmem line. +template +__device__ __forceinline__ void drain_smem_t_1d_to_gmem( + OType (&smem_T)[kTileDim][kSMemTRowStride], + OType* __restrict__ output_t_base, const size_t global_col_base, + const size_t global_row_base, const size_t R_total, const size_t K, const int tid) { + constexpr int kStorePerChunk = 16; + constexpr int kRowChunksPerCol = kTileDim / kStorePerChunk; // 8 + constexpr int kColsPerIter = kThreadsPerBlock / kRowChunksPerCol; // 32 + constexpr int kColIters = kTileDim / kColsPerIter; // 4 + const int warp_id = tid / kThreadsPerWarp; + const int lane = tid % kThreadsPerWarp; + const int col_in_warp = lane / kRowChunksPerCol; + const int row_chunk = lane % kRowChunksPerCol; + const int out_row_off = row_chunk * kStorePerChunk; +#pragma unroll + for (int it = 0; it < kColIters; ++it) { + const int out_col_local = it * kColsPerIter + warp_id * 4 + col_in_warp; + const size_t out_col_global = global_col_base + out_col_local; + if (out_col_global < K) { + OType* out_ptr = + output_t_base + out_col_global * R_total + global_row_base + out_row_off; + Vec v; +#pragma unroll + for (int e = 0; e < kStorePerChunk; ++e) { + v.data.elt[e] = smem_T[out_col_local][out_row_off + e]; + } + v.store_to(out_ptr); + } + } +} + +// ----- 2D block scaling kernel with TMA input ------------------------------------------------ +// Pass 1: amax over a 128x128 TMA-loaded tile, with input vectors staged in +// registers. Pass 2: quantize from registers, emit rowwise output and the +// transposed smem_T tile, then drain smem_T to gmem. + +template +__global__ void __launch_bounds__(kThreadsPerBlock, 4) group_block_scaled_2d_tma_kernel( + const __grid_constant__ CUtensorMap tensor_map_input, OType* __restrict__ output_base, + OType* __restrict__ output_t_base, CType* __restrict__ scale_inv_base, + CType* __restrict__ scale_inv_t_base, const int64_t* __restrict__ tensor_offsets_ptr, + const size_t num_tensors, const size_t common_first_dim_blocks, const size_t K, + const size_t total_row_blocks, const size_t blocks_X, const size_t scale_stride_y, + const size_t scale_t_stride_y, const size_t R_total, const float epsilon, + const bool pow_2_scaling, const float* __restrict__ noop_ptr) { +#if __CUDA_ARCH__ >= 900 && __CUDA_ARCH__ < 1000 + if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) return; + + const size_t tile_x = blockIdx.x; + const size_t tile_y_global = blockIdx.y; + if (tile_y_global >= total_row_blocks) return; + + const size_t tile_row_stride = static_cast(kTileDim) * K; + const size_t tensor_id = find_tensor_id_by_block_y( + tile_y_global, num_tensors, common_first_dim_blocks, tile_row_stride, tensor_offsets_ptr); + const size_t tensor_block_y_base = + kSameBothDims ? (tensor_id * common_first_dim_blocks) + : tensor_block_y_base_from_offsets(tensor_id, tensor_offsets_ptr, + tile_row_stride); + const size_t tensor_row_blocks = + kSameBothDims + ? common_first_dim_blocks + : (tensor_block_y_base_from_offsets(tensor_id + 1, tensor_offsets_ptr, tile_row_stride) - + tensor_block_y_base); + if (tile_y_global >= tensor_block_y_base + tensor_row_blocks) return; + + const size_t global_row_base = tile_y_global * kTileDim; + const size_t global_col_base = tile_x * kTileDim; + + // Dynamic smem holds the IType input tile (TMA dest, must be 128 B aligned). + // warp_amaxes and tma_mbar are static smem. + extern __shared__ unsigned char smem_raw_2d_tma[]; + IType (*smem_in)[kTileDim] = + reinterpret_cast(align_smem_128(smem_raw_2d_tma)); + + __shared__ CType warp_amaxes[kNumWarps]; + __shared__ uint64_t tma_mbar; + + const int tid = threadIdx.x; + const bool leading_thread = (tid == 0); + + // ---- TMA async load of the input tile ---- + if (leading_thread) { + ptx::mbarrier_init(&tma_mbar, 1); + } + __syncthreads(); + if (leading_thread) { + constexpr uint32_t tx_bytes = kTileDim * kTileDim * sizeof(IType); + ptx::mbarrier_arrive_expect_tx(&tma_mbar, tx_bytes); + ptx::cp_async_bulk_tensor_2d_global_to_shared_cta( + reinterpret_cast(smem_in), reinterpret_cast(&tensor_map_input), + static_cast(global_col_base), static_cast(global_row_base), &tma_mbar); + } + ptx::mbarrier_wait_parity(&tma_mbar, 0); + if (leading_thread) ptx::mbarrier_invalid(&tma_mbar); + __syncthreads(); + + // ---- Pass 1: tile amax, staging input vectors in registers for reuse in pass 2 ---- + constexpr int kEltsPerThread = 8; + constexpr int kThreadsPerRow = kTileDim / kEltsPerThread; // 16 + constexpr int kRowsPerIter = kThreadsPerBlock / kThreadsPerRow; // 16 + constexpr int kIters = kTileDim / kRowsPerIter; // 8 + + const int thr_col = tid % kThreadsPerRow; + const int thr_row = tid / kThreadsPerRow; + + using IVec = Vec; + using OVec = Vec; + + IVec staged[kIters]; + CType thr_amax = 0.f; +#pragma unroll + for (int it = 0; it < kIters; ++it) { + const int r_local = thr_row + it * kRowsPerIter; + staged[it].load_from(&smem_in[r_local][thr_col * kEltsPerThread]); + thr_amax = fmaxf(thr_amax, compute_row_amax(staged[it])); + } + CType warp_amax = warp_reduce_max(thr_amax); + const int warp_id = tid / kThreadsPerWarp; + const int lane = tid % kThreadsPerWarp; + if (lane == 0) warp_amaxes[warp_id] = warp_amax; + __syncthreads(); + + CType block_amax = warp_amaxes[0]; +#pragma unroll + for (int w = 1; w < kNumWarps; ++w) { + block_amax = fmaxf(block_amax, warp_amaxes[w]); + } + const CType scale = compute_scale_from_types(block_amax, epsilon, pow_2_scaling); + + if (leading_thread) { + const CType scale_inv = 1.f / scale; + if constexpr (kReturnRowwise) { + scale_inv_base[tile_y_global * scale_stride_y + tile_x] = scale_inv; + } + if constexpr (kReturnColwise) { + if constexpr (kSwizzledScales) { + scale_inv_t_base[swizzled_colwise_scale_idx(tile_x, tile_y_global, total_row_blocks)] = + scale_inv; + } else { + // Transposed CW: [blocks_X, total_row_blocks_padded]. Host passes + // scale_t_stride_y = align4(total_row_blocks). + scale_inv_t_base[tile_x * scale_t_stride_y + tile_y_global] = scale_inv; + } + } + } + + // ---- Pass 2: quantize from register-staged inputs, emit rowwise + colwise outputs ---- + // 2D block-scaling uses a per-tile (128x128) scalar scale, so the row-wise and + // column-wise quantized bytes are identical -- only the gmem layout differs. The + // columnwise buffer is physically transposed (cuBLAS FP8 block-scaling GEMM is + // TN-only), so we stage into smem_T and drain with stride-R_total stores. + if constexpr (kReturnColwise) { + constexpr int kSMemTRowStride = kTileDim + 4; + __shared__ OType smem_T[kTileDim][kSMemTRowStride]; + // Same delta for all 8 elements this thread writes since (thr_col*8 + e) >> 3 == thr_col. + const int swz_delta_w = smem_t_swz_delta(thr_col * kEltsPerThread); +#pragma unroll + for (int it = 0; it < kIters; ++it) { + const int row_local = thr_row + it * kRowsPerIter; + const size_t r = global_row_base + row_local; + const size_t c = global_col_base + thr_col * kEltsPerThread; + OVec qo; + quantize_row_vec(qo, staged[it], scale); + if constexpr (kReturnRowwise) { + if (c < K) { + const size_t count = (c + kEltsPerThread <= K) ? kEltsPerThread : (K - c); + qo.store_to_elts(output_base + r * K + c, 0, count); + } + } + const int c_phys = row_local ^ swz_delta_w; +#pragma unroll + for (int e = 0; e < kEltsPerThread; ++e) { + smem_T[thr_col * kEltsPerThread + e][c_phys] = qo.data.elt[e]; + } + } + __syncthreads(); + + // Drain smem_T to gmem: 4 cols/warp, 8 lanes/col, each lane stores a 16-row chunk + // so the 8 lanes of a col emit one 128 B gmem line. + constexpr int kStorePerChunk = 16; + constexpr int kRowChunksPerCol = kTileDim / kStorePerChunk; // 8 + constexpr int kColsPerIter = kThreadsPerBlock / kRowChunksPerCol; // 32 + constexpr int kColIters = kTileDim / kColsPerIter; // 4 + const int col_in_warp = lane / kRowChunksPerCol; + const int row_chunk = lane % kRowChunksPerCol; + const int out_row_off = row_chunk * kStorePerChunk; +#pragma unroll + for (int it = 0; it < kColIters; ++it) { + const int out_col_local = it * kColsPerIter + warp_id * 4 + col_in_warp; + const size_t out_col_global = global_col_base + out_col_local; + if (out_col_global < K) { + OType* out_ptr = + output_t_base + out_col_global * R_total + global_row_base + out_row_off; + // Per-byte unswizzled reads; LDS.128 is unsafe here because the 132 B smem_T row + // stride is not 16 B aligned for arbitrary column offsets. + const int swz_delta_r = smem_t_swz_delta(out_col_local); + Vec v; +#pragma unroll + for (int e = 0; e < kStorePerChunk; ++e) { + v.data.elt[e] = smem_T[out_col_local][(out_row_off + e) ^ swz_delta_r]; + } + v.store_to(out_ptr); + } + } + } else if constexpr (kReturnRowwise) { +#pragma unroll + for (int it = 0; it < kIters; ++it) { + const int row_local = thr_row + it * kRowsPerIter; + const size_t r = global_row_base + row_local; + const size_t c = global_col_base + thr_col * kEltsPerThread; + OVec qo; + quantize_row_vec(qo, staged[it], scale); + if (c < K) { + const size_t count = (c + kEltsPerThread <= K) ? kEltsPerThread : (K - c); + qo.store_to_elts(output_base + r * K + c, 0, count); + } + } + } +#endif // __CUDA_ARCH__ >= 900 && __CUDA_ARCH__ < 1000 +} + +// ----- 1D block scaling rowwise-only kernel ---------------------------------------------------- +// No smem cache. Each thread loads 16 cols/row, reduces amax across the 8 +// row-mates with shfl_xor, then quantizes and stores. + +template +__global__ void __launch_bounds__(kThreadsPerBlock) group_block_scaled_1d_rw_kernel( + const IType* __restrict__ input_base, OType* __restrict__ output_base, + CType* __restrict__ scale_inv_base, const int64_t* __restrict__ tensor_offsets_ptr, + const size_t num_tensors, const size_t common_first_dim_blocks, const size_t K, + const size_t total_row_blocks, const size_t scale_stride_aligned_R, const size_t R_total, + const float epsilon, const bool pow_2_scaling, const float* __restrict__ noop_ptr) { +#if __CUDA_ARCH__ >= 900 && __CUDA_ARCH__ < 1000 + if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) return; + + const size_t tile_x = blockIdx.x; + const size_t tile_y_global = blockIdx.y; + if (tile_y_global >= total_row_blocks) return; + + const size_t tile_row_stride = static_cast(kTileDim) * K; + const size_t tensor_id = find_tensor_id_by_block_y( + tile_y_global, num_tensors, common_first_dim_blocks, tile_row_stride, tensor_offsets_ptr); + const size_t tensor_block_y_base = + kSameBothDims ? (tensor_id * common_first_dim_blocks) + : tensor_block_y_base_from_offsets(tensor_id, tensor_offsets_ptr, + tile_row_stride); + const size_t tensor_row_blocks = + kSameBothDims + ? common_first_dim_blocks + : (tensor_block_y_base_from_offsets(tensor_id + 1, tensor_offsets_ptr, tile_row_stride) - + tensor_block_y_base); + if (tile_y_global >= tensor_block_y_base + tensor_row_blocks) return; + + const size_t global_row_base = tile_y_global * kTileDim; + const size_t global_col_base = tile_x * kTileDim; + + // 8 threads per row x 16 cols/thread = one 128 B gmem cache line per row, 4 iters per tile. + constexpr int kThreadsPerRow = 8; + constexpr int kVec = 16; + constexpr int kRowsPerIter = kThreadsPerBlock / kThreadsPerRow; // 32 + constexpr int kIters = kTileDim / kRowsPerIter; // 4 + + const int tid = threadIdx.x; + const int thr_col = tid % kThreadsPerRow; // 0..7 + const int thr_row = tid / kThreadsPerRow; // 0..31 (row index within an iter) + const size_t c = global_col_base + static_cast(thr_col) * kVec; + + Vec in_vec[kIters]; + +#pragma unroll + for (int it = 0; it < kIters; ++it) { + const int row_local = thr_row + it * kRowsPerIter; + const size_t r_global = global_row_base + row_local; + + // Load this thread's 16 cols of row `row_local`. + if (c + kVec <= K) { + in_vec[it].load_from(input_base + r_global * K + c); + } else if (c < K) { + in_vec[it].load_from_elts(input_base + r_global * K + c, 0, K - c); + } else { + in_vec[it].clear(); + } + + CType amax = compute_row_amax(in_vec[it]); + amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, 1)); + amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, 2)); + amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, 4)); + + const CType scale = compute_scale_from_types(amax, epsilon, pow_2_scaling); + const CType scale_inv = 1.f / scale; + if (thr_col == 0 && r_global < R_total) { + scale_inv_base[r_global + tile_x * scale_stride_aligned_R] = scale_inv; + } + + if (r_global < R_total) { + Vec out_vec; + quantize_row_vec(out_vec, in_vec[it], scale); + if (c + kVec <= K) { + out_vec.store_to(output_base + r_global * K + c); + } else if (c < K) { + out_vec.store_to_elts(output_base + r_global * K + c, 0, K - c); + } + } + } +#endif // __CUDA_ARCH__ >= 900 && __CUDA_ARCH__ < 1000 +} + +// ----- 1D block scaling kernel with TMA input ------------------------------------------------ +// CW and BOTH path. TMA fills a 128x128 smem input cache. RW pass reads rows +// and stores quantized output. CW pass stages a column slice in registers, +// computes the per-column amax there, then fills smem_T and drains to gmem. + +template +__global__ void __launch_bounds__(kThreadsPerBlock) group_block_scaled_1d_tma_kernel( + const __grid_constant__ CUtensorMap tensor_map_input, OType* __restrict__ output_base, + OType* __restrict__ output_t_base, CType* __restrict__ scale_inv_base, + CType* __restrict__ scale_inv_t_base, const int64_t* __restrict__ tensor_offsets_ptr, + const size_t num_tensors, const size_t common_first_dim_blocks, const size_t K, + const size_t total_row_blocks, const size_t blocks_X, const size_t scale_stride_aligned_R, + const size_t scale_t_stride_aligned_K, const size_t R_total, const float epsilon, + const bool pow_2_scaling, const float* __restrict__ noop_ptr) { +#if __CUDA_ARCH__ >= 900 && __CUDA_ARCH__ < 1000 + if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) return; + + const size_t tile_x = blockIdx.x; + const size_t tile_y_global = blockIdx.y; + if (tile_y_global >= total_row_blocks) return; + + const size_t tile_row_stride = static_cast(kTileDim) * K; + const size_t tensor_id = find_tensor_id_by_block_y( + tile_y_global, num_tensors, common_first_dim_blocks, tile_row_stride, tensor_offsets_ptr); + const size_t tensor_block_y_base = + kSameBothDims ? (tensor_id * common_first_dim_blocks) + : tensor_block_y_base_from_offsets(tensor_id, tensor_offsets_ptr, + tile_row_stride); + const size_t tensor_row_blocks = + kSameBothDims + ? common_first_dim_blocks + : (tensor_block_y_base_from_offsets(tensor_id + 1, tensor_offsets_ptr, tile_row_stride) - + tensor_block_y_base); + if (tile_y_global >= tensor_block_y_base + tensor_row_blocks) return; + + const size_t global_row_base = tile_y_global * kTileDim; + const size_t global_col_base = tile_x * kTileDim; + + // Dynamic smem: IType[kTileDim][kTileDim], 128 B aligned for TMA. Static smem + // (smem_T when CW, tma_mbar) lives outside the dynamic region. + extern __shared__ unsigned char smem_raw_1d_tma[]; + unsigned char* smem_base = align_smem_128(smem_raw_1d_tma); + IType (*smem)[kTileDim] = reinterpret_cast(smem_base); + + __shared__ uint64_t tma_mbar; + const int tid = threadIdx.x; + const bool leading_thread = (tid == 0); + + // ---- TMA async load of the input tile ---- + if (leading_thread) { + ptx::mbarrier_init(&tma_mbar, 1); + } + __syncthreads(); + if (leading_thread) { + constexpr uint32_t tx_bytes = kTileDim * kTileDim * sizeof(IType); + ptx::mbarrier_arrive_expect_tx(&tma_mbar, tx_bytes); + ptx::cp_async_bulk_tensor_2d_global_to_shared_cta( + reinterpret_cast(smem_base), + reinterpret_cast(&tensor_map_input), + static_cast(global_col_base), static_cast(global_row_base), &tma_mbar); + } + ptx::mbarrier_wait_parity(&tma_mbar, 0); + if (leading_thread) ptx::mbarrier_invalid(&tma_mbar); + __syncthreads(); + + // ---- RW pass (1x128 scale per row) ---- + // 8 t/row, vec-16 reads from smem; emits rowwise gmem directly. Only entered + // when CW is also requested (BOTH) -- RW-only requests use the dedicated + // group_block_scaled_1d_rw_kernel which skips the TMA load. + if constexpr (kReturnRowwise) { + constexpr int kThreadsPerRowRW = 8; + constexpr int kVec = 16; + constexpr int kRowsPerIterRW = kThreadsPerBlock / kThreadsPerRowRW; // 32 + constexpr int kRwIters = kTileDim / kRowsPerIterRW; // 4 + + const int rw_thr_col = tid % kThreadsPerRowRW; + const int rw_thr_row = tid / kThreadsPerRowRW; + const int col_local = rw_thr_col * kVec; + +#pragma unroll + for (int it = 0; it < kRwIters; ++it) { + const int row_local = rw_thr_row + it * kRowsPerIterRW; + Vec in_vec; + in_vec.load_from(&smem[row_local][col_local]); + + CType amax = compute_row_amax(in_vec); + amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, 1)); + amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, 2)); + amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, 4)); + + const CType scale = compute_scale_from_types(amax, epsilon, pow_2_scaling); + const CType scale_inv = 1.f / scale; + + const size_t r_global = global_row_base + row_local; + const bool row_in_bounds = (r_global < R_total); + + if (row_in_bounds && rw_thr_col == 0) { + scale_inv_base[r_global + tile_x * scale_stride_aligned_R] = scale_inv; + } + + if (row_in_bounds) { + const size_t cc = global_col_base + col_local; + Vec out_vec; + quantize_row_vec(out_vec, in_vec, scale); + if (cc + kVec <= K) { + out_vec.store_to(output_base + r_global * K + cc); + } else if (cc < K) { + out_vec.store_to_elts(output_base + r_global * K + cc, 0, K - cc); + } + } + } + } + + if constexpr (kReturnRowwise && kReturnColwise) { + __syncthreads(); + } + + // ---- CW pass (128x1 scale per column) ---- + // 2 t/col, 64-row column slice per thread. The columnwise buffer is physically + // transposed (cuBLAS FP8 block-scaling GEMM is TN-only), so we stage in smem_T + // and drain with stride-R_total stores. + if constexpr (kReturnColwise) { + constexpr int kSMemTRowStride = kTileDim + 4; + __shared__ OType smem_T[kTileDim][kSMemTRowStride]; + + constexpr int kThreadsPerColCW = 2; + constexpr int kRowsPerThreadCW = kTileDim / kThreadsPerColCW; + + const int col_local = tid / kThreadsPerColCW; + const int sub = tid % kThreadsPerColCW; + const int row_start = sub * kRowsPerThreadCW; + + // Stage the column slice in registers; the quantize step below reuses these + // instead of re-reading smem. + CType reg_data[kRowsPerThreadCW]; + CType amax = 0.f; +#pragma unroll + for (int e = 0; e < kRowsPerThreadCW; ++e) { + reg_data[e] = static_cast(smem[row_start + e][col_local]); + amax = fmaxf(amax, fabsf(reg_data[e])); + } + amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, 1)); + + const CType scale = compute_scale_from_types(amax, epsilon, pow_2_scaling); + const CType scale_inv = 1.f / scale; + + const size_t c_global = global_col_base + col_local; + const bool col_in_bounds = (c_global < K); + + if (col_in_bounds && sub == 0) { + if constexpr (kSwizzledScales) { + scale_inv_t_base[swizzled_colwise_scale_idx(c_global, tile_y_global, total_row_blocks)] = + scale_inv; + } else { + scale_inv_t_base[c_global + tile_y_global * scale_t_stride_aligned_K] = scale_inv; + } + } + + if (col_in_bounds) { +#pragma unroll + for (int e = 0; e < kRowsPerThreadCW; ++e) { + smem_T[col_local][row_start + e] = static_cast(reg_data[e] * scale); + } + } + __syncthreads(); + + drain_smem_t_1d_to_gmem( + smem_T, output_t_base, global_col_base, global_row_base, R_total, K, tid); + } +#endif // __CUDA_ARCH__ >= 900 && __CUDA_ARCH__ < 1000 +} + +// ----- Host-side dispatchers -------------------------------------------------------------------- + +inline size_t align_up_to(size_t x, size_t a) { return ((x + a - 1) / a) * a; } + +struct GroupedBlockwiseLaunchInfo { + size_t num_tensors; + size_t K; + size_t R_total; + size_t common_first_dim_blocks; + size_t total_row_blocks; + size_t blocks_X; + bool same_both_dims; + const int64_t* tensor_offsets_d = nullptr; +}; + +inline GroupedBlockwiseLaunchInfo prepare_grouped_blockwise_launch(const GroupedTensor* output) { + GroupedBlockwiseLaunchInfo info{}; + const bool same_both_dims = output->all_same_shape(); + const bool varying_first_dim = (!output->all_same_first_dim()) && output->all_same_last_dim(); + NVTE_CHECK(same_both_dims || varying_first_dim, + "Grouped FP8 block-scaling supports only SAME_BOTH_DIMS and VARYING_FIRST_DIM " + "shape representations."); + + info.same_both_dims = same_both_dims; + info.num_tensors = output->num_tensors; + info.K = output->get_common_last_dim(); + NVTE_CHECK(info.K % 16 == 0, "Last dim must be multiple of 16 (FP8 alignment)."); + + if (same_both_dims) { + const size_t common_first_dim = output->get_common_first_dim(); + NVTE_CHECK(common_first_dim % kTileDim == 0, + "SAME_BOTH_DIMS first dim must be multiple of 128."); + info.common_first_dim_blocks = common_first_dim / kTileDim; + info.R_total = info.num_tensors * common_first_dim; + } else { + info.common_first_dim_blocks = 0; + info.R_total = output->logical_shape.data[0]; + info.tensor_offsets_d = reinterpret_cast(output->tensor_offsets.dptr); + NVTE_CHECK(info.tensor_offsets_d != nullptr, + "VARYING_FIRST_DIM requires tensor_offsets to be set on the GroupedTensor."); + } + info.total_row_blocks = (info.R_total + kTileDim - 1) / kTileDim; + info.blocks_X = (info.K + kTileDim - 1) / kTileDim; + return info; +} + +// Public dispatch — 2D block scaling. +inline void group_quantize_blockwise_2d(const GroupedTensor* input, GroupedTensor* output, + const Tensor* noop, const float epsilon, + const bool pow_2_scaling, cudaStream_t stream) { + const int sm = transformer_engine::cuda::sm_arch(); + NVTE_CHECK(sm >= 90 && sm < 100, + "Grouped FP8 block-scaling quantize is only supported on Hopper (SM90-SM99); " + "use MXFP8 on Blackwell (SM100) or newer. Got SM", + sm, "."); + const bool use_rowwise = output->has_data(); + const bool use_colwise = output->has_columnwise_data(); + NVTE_CHECK(use_rowwise || use_colwise, + "Either rowwise or columnwise output data must be allocated."); + NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must be FP8."); + NVTE_CHECK(input->num_tensors == output->num_tensors, + "Input and output must have same num_tensors."); + + auto info = prepare_grouped_blockwise_launch(output); + if (info.R_total == 0 || info.K == 0) return; + + using CType = float; + const float* noop_ptr = + (noop != nullptr) ? reinterpret_cast(noop->data.dptr) : nullptr; + + const size_t scale_stride_y = align_up_to(info.blocks_X, 4); + // CW scales are stored [blocks_X, align4(total_row_blocks)] -- transposed to + // match the physically-transposed columnwise data the TN cuBLAS GEMM consumes. + const size_t scale_t_stride_y = align_up_to(info.total_row_blocks, 4); + const bool swizzled = output->with_gemm_swizzled_scales; + + dim3 grid(info.blocks_X, info.total_row_blocks, 1); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input->dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + info.same_both_dims, kSameBothDims, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_rowwise, kRowwise, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_colwise, kColwise, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + swizzled, kSwizzled, + if constexpr (kRowwise || kColwise) { + CUtensorMap tensor_map_input{}; + create_2D_tensor_map(tensor_map_input, input->data, info.R_total, + info.K, kTileDim, kTileDim, info.K, 0, + sizeof(IType) * 8); + auto tma_kernel = + group_block_scaled_2d_tma_kernel; + const size_t smem_bytes = kTileDim * kTileDim * sizeof(IType) + + TMA_SHMEM_ALIGNMENT - 1; + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + tma_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + static_cast(smem_bytes))); + tma_kernel<<>>( + tensor_map_input, + kRowwise ? reinterpret_cast(output->data.dptr) : nullptr, + kColwise ? reinterpret_cast( + output->columnwise_data.dptr) + : nullptr, + kRowwise ? reinterpret_cast(output->scale_inv.dptr) + : nullptr, + kColwise ? reinterpret_cast( + output->columnwise_scale_inv.dptr) + : nullptr, + info.tensor_offsets_d, info.num_tensors, + info.common_first_dim_blocks, info.K, info.total_row_blocks, + info.blocks_X, scale_stride_y, scale_t_stride_y, info.R_total, + epsilon, pow_2_scaling, noop_ptr); + })))))); + + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +// Public dispatch — 1D block scaling. +inline void group_quantize_blockwise_1d(const GroupedTensor* input, GroupedTensor* output, + const Tensor* noop, const float epsilon, + const bool pow_2_scaling, cudaStream_t stream) { + const int sm = transformer_engine::cuda::sm_arch(); + NVTE_CHECK(sm >= 90 && sm < 100, + "Grouped FP8 block-scaling quantize is only supported on Hopper (SM90); " + "use MXFP8 on Blackwell (SM100) or newer. Got SM", + sm, "."); + const bool use_rowwise = output->has_data(); + const bool use_colwise = output->has_columnwise_data(); + NVTE_CHECK(use_rowwise || use_colwise, + "Either rowwise or columnwise output data must be allocated."); + NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must be FP8."); + NVTE_CHECK(input->num_tensors == output->num_tensors, + "Input and output must have same num_tensors."); + + auto info = prepare_grouped_blockwise_launch(output); + if (info.R_total == 0 || info.K == 0) return; + + using CType = float; + const float* noop_ptr = + (noop != nullptr) ? reinterpret_cast(noop->data.dptr) : nullptr; + + const size_t scale_stride_aligned_R = align_up_to(info.R_total, 4); + const size_t scale_t_stride_aligned_K = align_up_to(info.K, 4); + const bool swizzled = output->with_gemm_swizzled_scales; + + dim3 grid(info.blocks_X, info.total_row_blocks, 1); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input->dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + info.same_both_dims, kSameBothDims, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_rowwise, kRowwise, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_colwise, kColwise, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + swizzled, kSwizzled, + if constexpr (kRowwise && !kColwise) { + // RW-only: no smem cache, no TMA benefit, no colwise scales. + group_block_scaled_1d_rw_kernel + <<>>( + reinterpret_cast(input->data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale_inv.dptr), + info.tensor_offsets_d, info.num_tensors, + info.common_first_dim_blocks, info.K, info.total_row_blocks, + scale_stride_aligned_R, info.R_total, epsilon, pow_2_scaling, + noop_ptr); + } else if constexpr (kRowwise || kColwise) { + // CW-only or BOTH: smem-cached kernel with TMA bulk load. + const size_t smem_bytes = kTileDim * kTileDim * sizeof(IType); + constexpr size_t kStaticSmemCWBytes = + (kTileDim * (kTileDim + 4)) * sizeof(OType); + const size_t static_smem_bytes = kColwise ? kStaticSmemCWBytes : 0; + const size_t tma_smem_bytes = + smem_bytes + TMA_SHMEM_ALIGNMENT - 1; + const size_t total_smem_tma = tma_smem_bytes + static_smem_bytes; + auto tma_kernel = + group_block_scaled_1d_tma_kernel; + if (total_smem_tma >= 48 * 1024) { + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + tma_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + static_cast(tma_smem_bytes))); + } + CUtensorMap tensor_map_input{}; + create_2D_tensor_map(tensor_map_input, input->data, info.R_total, + info.K, kTileDim, kTileDim, info.K, 0, + sizeof(IType) * 8); + tma_kernel<<>>( + tensor_map_input, + kRowwise ? reinterpret_cast(output->data.dptr) : nullptr, + kColwise ? reinterpret_cast( + output->columnwise_data.dptr) + : nullptr, + kRowwise ? reinterpret_cast(output->scale_inv.dptr) + : nullptr, + kColwise ? reinterpret_cast( + output->columnwise_scale_inv.dptr) + : nullptr, + info.tensor_offsets_d, info.num_tensors, + info.common_first_dim_blocks, info.K, info.total_row_blocks, + info.blocks_X, scale_stride_aligned_R, scale_t_stride_aligned_K, + info.R_total, epsilon, pow_2_scaling, noop_ptr); + })))))); + + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +} // namespace fp8_blockwise +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_GROUP_QUANTIZE_FP8_BLOCKWISE_CUH_ diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index 88a57fe989..2814aa3490 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -128,22 +128,22 @@ constexpr bool is_supported_arch() { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init __device__ __forceinline__ void mbarrier_init(uint64_t *mbar, const uint32_t count) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); asm volatile("mbarrier.init.shared.b64 [%0], %1;" ::"r"(mbar_ptr), "r"(count) : "memory"); #else - NVTE_DEVICE_ERROR("mbarrier_init is only supported on SM 10.0+."); -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + NVTE_DEVICE_ERROR("mbarrier_init is only supported on SM 9.0+."); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-inval __device__ __forceinline__ void mbarrier_invalid(uint64_t *mbar) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); asm volatile("mbarrier.inval.shared.b64 [%0];" ::"r"(mbar_ptr) : "memory"); #else - NVTE_DEVICE_ERROR("mbarrier_invalid is only supported on SM 10.0+."); -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + NVTE_DEVICE_ERROR("mbarrier_invalid is only supported on SM 9.0+."); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive @@ -158,13 +158,13 @@ __device__ __forceinline__ void mbarrier_arrive(uint64_t *mbar) { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive __device__ __forceinline__ void mbarrier_arrive_expect_tx(uint64_t *mbar, const uint32_t tx_count) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;" ::"r"(mbar_ptr), "r"(tx_count) : "memory"); #else - NVTE_DEVICE_ERROR("mbarrier_arrive_expect_tx is only supported on SM 10.0+."); -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + NVTE_DEVICE_ERROR("mbarrier_arrive_expect_tx is only supported on SM 9.0+."); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } __device__ __forceinline__ void mbarrier_arrive_expect_tx_cta_relaxed_shared_cta( @@ -230,8 +230,26 @@ __device__ __forceinline__ void cp_async_bulk_tensor_2d_global_to_shared( #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } +// global -> shared::cta (no cluster; valid on Hopper sm_90+ and Blackwell with +// cluster size 1). Used by the FP8 block-scaling grouped quantize kernels. +__device__ __forceinline__ void cp_async_bulk_tensor_2d_global_to_shared_cta( + uint64_t *dst_shmem, const uint64_t *tensor_map_ptr, const uint32_t offset_x, + const uint32_t offset_y, uint64_t *mbar) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + uint32_t dst_shmem_ptr = __cvta_generic_to_shared(dst_shmem); + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile( + "cp.async.bulk.tensor.2d.shared::cta.global.tile" + ".mbarrier::complete_tx::bytes [%0], [%1, {%2, %3}], [%4];" ::"r"(dst_shmem_ptr), + "l"(tensor_map_ptr), "r"(offset_x), "r"(offset_y), "r"(mbar_ptr) + : "memory"); +#else + NVTE_DEVICE_ERROR("cp_async_bulk_tensor_2d_global_to_shared_cta is only supported on SM 9.0+."); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) +} + __device__ __forceinline__ bool mbarrier_try_wait_parity(uint32_t mbar_ptr, const uint32_t parity) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) uint32_t waitComplete; asm volatile( "{\n\t .reg .pred P_OUT; \n\t" @@ -243,19 +261,19 @@ __device__ __forceinline__ bool mbarrier_try_wait_parity(uint32_t mbar_ptr, cons : "memory"); return static_cast(waitComplete); #else - NVTE_DEVICE_ERROR("mbarrier_try_wait_parity is only supported on SM 10.0+."); -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + NVTE_DEVICE_ERROR("mbarrier_try_wait_parity is only supported on SM 9.0+."); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) return true; } __device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint32_t parity) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); while (!mbarrier_try_wait_parity(mbar_ptr, parity)) { } #else - NVTE_DEVICE_ERROR("mbarrier_wait_parity is only supported on SM 10.0+."); -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + NVTE_DEVICE_ERROR("mbarrier_wait_parity is only supported on SM 9.0+."); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } __device__ __forceinline__ void mbarrier_wait_parity_acquire_cta_shared_cta(uint64_t *mbar, diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index aab5a87b9a..e73886e5ca 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -241,6 +241,7 @@ py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const enum class GroupedQuantizationMode { MXFP8_GROUPED_QUANTIZE, NVFP4_GROUPED_QUANTIZE, + FP8_BLOCKWISE_GROUPED_QUANTIZE, INVALID_FOR_GROUPED_QUANTIZE }; GroupedQuantizationMode grouped_quantization_mode = @@ -249,6 +250,8 @@ py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const grouped_quantization_mode = GroupedQuantizationMode::MXFP8_GROUPED_QUANTIZE; } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { grouped_quantization_mode = GroupedQuantizationMode::NVFP4_GROUPED_QUANTIZE; + } else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) { + grouped_quantization_mode = GroupedQuantizationMode::FP8_BLOCKWISE_GROUPED_QUANTIZE; } if (empty_input_buffer) { @@ -274,9 +277,21 @@ py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const }); break; } + case GroupedQuantizationMode::FP8_BLOCKWISE_GROUPED_QUANTIZE: { + Float8BlockQuantizer *fp8_block_quantizer_cpp = + static_cast(quantizer_cpp.get()); + QuantizationConfigWrapper quant_config_cpp; + quant_config_cpp.set_force_pow_2_scales(fp8_block_quantizer_cpp->force_pow_2_scales); + quant_config_cpp.set_amax_epsilon(fp8_block_quantizer_cpp->amax_epsilon); + NVTE_SCOPED_GIL_RELEASE({ + nvte_group_quantize(grouped_input_tensor.data(), grouped_output_tensor_cpp.data(), + quant_config_cpp, at::cuda::getCurrentCUDAStream()); + }); + break; + } case GroupedQuantizationMode::INVALID_FOR_GROUPED_QUANTIZE: default: - NVTE_ERROR("group_quantize: only support NVFP4 or MXFP8 quantizer."); + NVTE_ERROR("group_quantize: only support NVFP4, MXFP8, or Float8Blockwise quantizer."); break; } From 0a6b105d117c6c4bc4cd335314c5b75b0516d9cc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 17 Jun 2026 13:02:34 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../group_quantize_fp8_blockwise.cuh | 114 ++++++++---------- 1 file changed, 53 insertions(+), 61 deletions(-) diff --git a/transformer_engine/common/cast/fp8_blockwise/group_quantize_fp8_blockwise.cuh b/transformer_engine/common/cast/fp8_blockwise/group_quantize_fp8_blockwise.cuh index b7b49ebd9a..7b7ac5563a 100644 --- a/transformer_engine/common/cast/fp8_blockwise/group_quantize_fp8_blockwise.cuh +++ b/transformer_engine/common/cast/fp8_blockwise/group_quantize_fp8_blockwise.cuh @@ -57,10 +57,10 @@ __device__ __forceinline__ unsigned char* align_smem_128(unsigned char* p) { // j = row-block index (tile_y_global) // num_tiles_X = DIVUP(total_row_blocks, TILE_DIM_X=4) __device__ __forceinline__ size_t swizzled_colwise_scale_idx(size_t i, size_t j, - size_t total_row_blocks) { + size_t total_row_blocks) { using namespace transformer_engine::dispatch::mxfp8::swizzle; - const size_t num_tiles_X = (total_row_blocks + GEMM_SWIZZLED_SCALE_TILE_DIM_X - 1) / - GEMM_SWIZZLED_SCALE_TILE_DIM_X; + const size_t num_tiles_X = + (total_row_blocks + GEMM_SWIZZLED_SCALE_TILE_DIM_X - 1) / GEMM_SWIZZLED_SCALE_TILE_DIM_X; return gemm_swizzled_scale_idx(i, j, num_tiles_X); } @@ -70,9 +70,8 @@ __device__ __forceinline__ size_t swizzled_colwise_scale_idx(size_t i, size_t j, // `tensor_offsets_ptr` (element offsets) directly on device. template __device__ __forceinline__ size_t find_tensor_id_by_block_y( - const size_t block_y_global, const size_t num_tensors, - const size_t common_first_dim_blocks, const size_t tile_row_stride, - const int64_t* __restrict__ tensor_offsets_ptr) { + const size_t block_y_global, const size_t num_tensors, const size_t common_first_dim_blocks, + const size_t tile_row_stride, const int64_t* __restrict__ tensor_offsets_ptr) { if constexpr (kSameBothDims) { return block_y_global / common_first_dim_blocks; } @@ -122,8 +121,8 @@ __device__ __forceinline__ CType compute_row_amax(const Vec& v) { // Per-vector multiply-and-quantize via fp32 intermediates. template -__device__ __forceinline__ void quantize_row_vec(Vec& out, - const Vec& in, CType scale) { +__device__ __forceinline__ void quantize_row_vec(Vec& out, const Vec& in, + CType scale) { #pragma unroll for (int e = 0; e < kVec; ++e) { out.data.elt[e] = static_cast(static_cast(in.data.elt[e]) * scale); @@ -139,14 +138,16 @@ __device__ __forceinline__ int smem_t_swz_delta(int smem_t_row) { // Drain smem_T to gmem for the 1D CW path: 4 cols/warp, 8 lanes/col, each lane // stores a 16-row chunk so the 8 lanes of a col emit one 128 B gmem line. template -__device__ __forceinline__ void drain_smem_t_1d_to_gmem( - OType (&smem_T)[kTileDim][kSMemTRowStride], - OType* __restrict__ output_t_base, const size_t global_col_base, - const size_t global_row_base, const size_t R_total, const size_t K, const int tid) { +__device__ __forceinline__ void drain_smem_t_1d_to_gmem(OType (&smem_T)[kTileDim][kSMemTRowStride], + OType* __restrict__ output_t_base, + const size_t global_col_base, + const size_t global_row_base, + const size_t R_total, const size_t K, + const int tid) { constexpr int kStorePerChunk = 16; - constexpr int kRowChunksPerCol = kTileDim / kStorePerChunk; // 8 - constexpr int kColsPerIter = kThreadsPerBlock / kRowChunksPerCol; // 32 - constexpr int kColIters = kTileDim / kColsPerIter; // 4 + constexpr int kRowChunksPerCol = kTileDim / kStorePerChunk; // 8 + constexpr int kColsPerIter = kThreadsPerBlock / kRowChunksPerCol; // 32 + constexpr int kColIters = kTileDim / kColsPerIter; // 4 const int warp_id = tid / kThreadsPerWarp; const int lane = tid % kThreadsPerWarp; const int col_in_warp = lane / kRowChunksPerCol; @@ -157,8 +158,7 @@ __device__ __forceinline__ void drain_smem_t_1d_to_gmem( const int out_col_local = it * kColsPerIter + warp_id * 4 + col_in_warp; const size_t out_col_global = global_col_base + out_col_local; if (out_col_global < K) { - OType* out_ptr = - output_t_base + out_col_global * R_total + global_row_base + out_row_off; + OType* out_ptr = output_t_base + out_col_global * R_total + global_row_base + out_row_off; Vec v; #pragma unroll for (int e = 0; e < kStorePerChunk; ++e) { @@ -195,9 +195,9 @@ __global__ void __launch_bounds__(kThreadsPerBlock, 4) group_block_scaled_2d_tma const size_t tensor_id = find_tensor_id_by_block_y( tile_y_global, num_tensors, common_first_dim_blocks, tile_row_stride, tensor_offsets_ptr); const size_t tensor_block_y_base = - kSameBothDims ? (tensor_id * common_first_dim_blocks) - : tensor_block_y_base_from_offsets(tensor_id, tensor_offsets_ptr, - tile_row_stride); + kSameBothDims + ? (tensor_id * common_first_dim_blocks) + : tensor_block_y_base_from_offsets(tensor_id, tensor_offsets_ptr, tile_row_stride); const size_t tensor_row_blocks = kSameBothDims ? common_first_dim_blocks @@ -211,8 +211,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock, 4) group_block_scaled_2d_tma // Dynamic smem holds the IType input tile (TMA dest, must be 128 B aligned). // warp_amaxes and tma_mbar are static smem. extern __shared__ unsigned char smem_raw_2d_tma[]; - IType (*smem_in)[kTileDim] = - reinterpret_cast(align_smem_128(smem_raw_2d_tma)); + IType(*smem_in)[kTileDim] = reinterpret_cast(align_smem_128(smem_raw_2d_tma)); __shared__ CType warp_amaxes[kNumWarps]; __shared__ uint64_t tma_mbar; @@ -238,9 +237,9 @@ __global__ void __launch_bounds__(kThreadsPerBlock, 4) group_block_scaled_2d_tma // ---- Pass 1: tile amax, staging input vectors in registers for reuse in pass 2 ---- constexpr int kEltsPerThread = 8; - constexpr int kThreadsPerRow = kTileDim / kEltsPerThread; // 16 + constexpr int kThreadsPerRow = kTileDim / kEltsPerThread; // 16 constexpr int kRowsPerIter = kThreadsPerBlock / kThreadsPerRow; // 16 - constexpr int kIters = kTileDim / kRowsPerIter; // 8 + constexpr int kIters = kTileDim / kRowsPerIter; // 8 const int thr_col = tid % kThreadsPerRow; const int thr_row = tid / kThreadsPerRow; @@ -320,9 +319,9 @@ __global__ void __launch_bounds__(kThreadsPerBlock, 4) group_block_scaled_2d_tma // Drain smem_T to gmem: 4 cols/warp, 8 lanes/col, each lane stores a 16-row chunk // so the 8 lanes of a col emit one 128 B gmem line. constexpr int kStorePerChunk = 16; - constexpr int kRowChunksPerCol = kTileDim / kStorePerChunk; // 8 + constexpr int kRowChunksPerCol = kTileDim / kStorePerChunk; // 8 constexpr int kColsPerIter = kThreadsPerBlock / kRowChunksPerCol; // 32 - constexpr int kColIters = kTileDim / kColsPerIter; // 4 + constexpr int kColIters = kTileDim / kColsPerIter; // 4 const int col_in_warp = lane / kRowChunksPerCol; const int row_chunk = lane % kRowChunksPerCol; const int out_row_off = row_chunk * kStorePerChunk; @@ -331,8 +330,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock, 4) group_block_scaled_2d_tma const int out_col_local = it * kColsPerIter + warp_id * 4 + col_in_warp; const size_t out_col_global = global_col_base + out_col_local; if (out_col_global < K) { - OType* out_ptr = - output_t_base + out_col_global * R_total + global_row_base + out_row_off; + OType* out_ptr = output_t_base + out_col_global * R_total + global_row_base + out_row_off; // Per-byte unswizzled reads; LDS.128 is unsafe here because the 132 B smem_T row // stride is not 16 B aligned for arbitrary column offsets. const int swz_delta_r = smem_t_swz_delta(out_col_local); @@ -383,9 +381,9 @@ __global__ void __launch_bounds__(kThreadsPerBlock) group_block_scaled_1d_rw_ker const size_t tensor_id = find_tensor_id_by_block_y( tile_y_global, num_tensors, common_first_dim_blocks, tile_row_stride, tensor_offsets_ptr); const size_t tensor_block_y_base = - kSameBothDims ? (tensor_id * common_first_dim_blocks) - : tensor_block_y_base_from_offsets(tensor_id, tensor_offsets_ptr, - tile_row_stride); + kSameBothDims + ? (tensor_id * common_first_dim_blocks) + : tensor_block_y_base_from_offsets(tensor_id, tensor_offsets_ptr, tile_row_stride); const size_t tensor_row_blocks = kSameBothDims ? common_first_dim_blocks @@ -403,8 +401,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock) group_block_scaled_1d_rw_ker constexpr int kIters = kTileDim / kRowsPerIter; // 4 const int tid = threadIdx.x; - const int thr_col = tid % kThreadsPerRow; // 0..7 - const int thr_row = tid / kThreadsPerRow; // 0..31 (row index within an iter) + const int thr_col = tid % kThreadsPerRow; // 0..7 + const int thr_row = tid / kThreadsPerRow; // 0..31 (row index within an iter) const size_t c = global_col_base + static_cast(thr_col) * kVec; Vec in_vec[kIters]; @@ -473,9 +471,9 @@ __global__ void __launch_bounds__(kThreadsPerBlock) group_block_scaled_1d_tma_ke const size_t tensor_id = find_tensor_id_by_block_y( tile_y_global, num_tensors, common_first_dim_blocks, tile_row_stride, tensor_offsets_ptr); const size_t tensor_block_y_base = - kSameBothDims ? (tensor_id * common_first_dim_blocks) - : tensor_block_y_base_from_offsets(tensor_id, tensor_offsets_ptr, - tile_row_stride); + kSameBothDims + ? (tensor_id * common_first_dim_blocks) + : tensor_block_y_base_from_offsets(tensor_id, tensor_offsets_ptr, tile_row_stride); const size_t tensor_row_blocks = kSameBothDims ? common_first_dim_blocks @@ -490,7 +488,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) group_block_scaled_1d_tma_ke // (smem_T when CW, tma_mbar) lives outside the dynamic region. extern __shared__ unsigned char smem_raw_1d_tma[]; unsigned char* smem_base = align_smem_128(smem_raw_1d_tma); - IType (*smem)[kTileDim] = reinterpret_cast(smem_base); + IType(*smem)[kTileDim] = reinterpret_cast(smem_base); __shared__ uint64_t tma_mbar; const int tid = threadIdx.x; @@ -614,8 +612,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock) group_block_scaled_1d_tma_ke } __syncthreads(); - drain_smem_t_1d_to_gmem( - smem_T, output_t_base, global_col_base, global_row_base, R_total, K, tid); + drain_smem_t_1d_to_gmem(smem_T, output_t_base, global_col_base, + global_row_base, R_total, K, tid); } #endif // __CUDA_ARCH__ >= 900 && __CUDA_ARCH__ < 1000 } @@ -709,32 +707,29 @@ inline void group_quantize_blockwise_2d(const GroupedTensor* input, GroupedTenso TRANSFORMER_ENGINE_SWITCH_CONDITION( use_colwise, kColwise, TRANSFORMER_ENGINE_SWITCH_CONDITION( - swizzled, kSwizzled, - if constexpr (kRowwise || kColwise) { + swizzled, kSwizzled, if constexpr (kRowwise || kColwise) { CUtensorMap tensor_map_input{}; create_2D_tensor_map(tensor_map_input, input->data, info.R_total, info.K, kTileDim, kTileDim, info.K, 0, sizeof(IType) * 8); auto tma_kernel = - group_block_scaled_2d_tma_kernel; - const size_t smem_bytes = kTileDim * kTileDim * sizeof(IType) + - TMA_SHMEM_ALIGNMENT - 1; + group_block_scaled_2d_tma_kernel; + const size_t smem_bytes = + kTileDim * kTileDim * sizeof(IType) + TMA_SHMEM_ALIGNMENT - 1; NVTE_CHECK_CUDA(cudaFuncSetAttribute( tma_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast(smem_bytes))); tma_kernel<<>>( tensor_map_input, kRowwise ? reinterpret_cast(output->data.dptr) : nullptr, - kColwise ? reinterpret_cast( - output->columnwise_data.dptr) + kColwise ? reinterpret_cast(output->columnwise_data.dptr) : nullptr, kRowwise ? reinterpret_cast(output->scale_inv.dptr) : nullptr, - kColwise ? reinterpret_cast( - output->columnwise_scale_inv.dptr) - : nullptr, + kColwise + ? reinterpret_cast(output->columnwise_scale_inv.dptr) + : nullptr, info.tensor_offsets_d, info.num_tensors, info.common_first_dim_blocks, info.K, info.total_row_blocks, info.blocks_X, scale_stride_y, scale_t_stride_y, info.R_total, @@ -803,13 +798,11 @@ inline void group_quantize_blockwise_1d(const GroupedTensor* input, GroupedTenso constexpr size_t kStaticSmemCWBytes = (kTileDim * (kTileDim + 4)) * sizeof(OType); const size_t static_smem_bytes = kColwise ? kStaticSmemCWBytes : 0; - const size_t tma_smem_bytes = - smem_bytes + TMA_SHMEM_ALIGNMENT - 1; + const size_t tma_smem_bytes = smem_bytes + TMA_SHMEM_ALIGNMENT - 1; const size_t total_smem_tma = tma_smem_bytes + static_smem_bytes; auto tma_kernel = - group_block_scaled_1d_tma_kernel; + group_block_scaled_1d_tma_kernel; if (total_smem_tma >= 48 * 1024) { NVTE_CHECK_CUDA(cudaFuncSetAttribute( tma_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, @@ -822,14 +815,13 @@ inline void group_quantize_blockwise_1d(const GroupedTensor* input, GroupedTenso tma_kernel<<>>( tensor_map_input, kRowwise ? reinterpret_cast(output->data.dptr) : nullptr, - kColwise ? reinterpret_cast( - output->columnwise_data.dptr) + kColwise ? reinterpret_cast(output->columnwise_data.dptr) : nullptr, kRowwise ? reinterpret_cast(output->scale_inv.dptr) : nullptr, - kColwise ? reinterpret_cast( - output->columnwise_scale_inv.dptr) - : nullptr, + kColwise + ? reinterpret_cast(output->columnwise_scale_inv.dptr) + : nullptr, info.tensor_offsets_d, info.num_tensors, info.common_first_dim_blocks, info.K, info.total_row_blocks, info.blocks_X, scale_stride_aligned_R, scale_t_stride_aligned_K,