diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 9b67c09f34..d5c446fb48 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -24,6 +24,7 @@ add_executable(test_operator test_cast_transpose_dbias_dgelu.cu test_cast_transpose_dgeglu.cu test_act.cu + test_scaled_activation.cu test_normalization.cu test_normalization_mxfp8.cu test_memset.cu diff --git a/tests/cpp/operator/test_scaled_activation.cu b/tests/cpp/operator/test_scaled_activation.cu new file mode 100644 index 0000000000..72a64a3c04 --- /dev/null +++ b/tests/cpp/operator/test_scaled_activation.cu @@ -0,0 +1,321 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include + +#include +#include + +#include + +#include "../test_common.h" + +using namespace transformer_engine; + +namespace { + +enum class ScaledActivationCase { + kSwiGLU, + kClampedSwiGLU, + kSReLU, +}; + +constexpr float kClampedLimit = 1.3f; +constexpr float kClampedAlpha = 1.702f; +constexpr float kClampedLinearOffset = 0.5f; + +const char *activation_name(ScaledActivationCase activation) { + switch (activation) { + case ScaledActivationCase::kSwiGLU: + return "scaled_swiglu"; + case ScaledActivationCase::kClampedSwiGLU: + return "scaled_clamped_swiglu"; + case ScaledActivationCase::kSReLU: + return "scaled_srelu"; + } + return "unknown"; +} + +inline float sigmoid(const float x) { return 1.0f / (1.0f + expf(-x)); } + +inline float qgelu_alpha(const float x, const float alpha) { return x * sigmoid(alpha * x); } + +inline float dqgelu_alpha(const float x, const float alpha) { + const float sig = sigmoid(alpha * x); + return alpha * x * sig * (1.0f - sig) + sig; +} + +inline float silu_ref(const float x) { return x * sigmoid(x); } + +inline float dsilu_ref(const float x) { + const float sig = sigmoid(x); + return x * sig * (1.0f - sig) + sig; +} + +inline float srelu_ref(const float x) { return x > 0.0f ? x * x : 0.0f; } + +inline float dsrelu_ref(const float x) { return fmaxf(0.0f, 2.0f * x); } + +inline void glu_indices(const size_t row, const size_t col, const size_t hidden, + const int64_t interleave, size_t *act_idx, size_t *linear_idx) { + if (interleave > 0) { + const size_t block = col / static_cast(interleave); + const size_t lane = col % static_cast(interleave); + const size_t base = row * hidden * 2 + block * static_cast(interleave) * 2 + lane; + *act_idx = base; + *linear_idx = base + static_cast(interleave); + } else { + const size_t base = row * hidden * 2; + *act_idx = base + col; + *linear_idx = base + hidden + col; + } +} + +inline float gated_unscaled(const ScaledActivationCase activation, const float act_in, + const float linear_in) { + switch (activation) { + case ScaledActivationCase::kSwiGLU: + return silu_ref(act_in) * linear_in; + case ScaledActivationCase::kClampedSwiGLU: { + const float act = qgelu_alpha(fminf(kClampedLimit, act_in), kClampedAlpha); + const float linear = + fminf(fmaxf(-kClampedLimit, linear_in), kClampedLimit) + kClampedLinearOffset; + return act * linear; + } + case ScaledActivationCase::kSReLU: + return srelu_ref(act_in); + } + return 0.0f; +} + +inline void gated_grads(const ScaledActivationCase activation, const float act_in, + const float linear_in, float *dact, float *dlinear, float *unscaled) { + switch (activation) { + case ScaledActivationCase::kSwiGLU: { + const float act = silu_ref(act_in); + *unscaled = act * linear_in; + *dact = dsilu_ref(act_in) * linear_in; + *dlinear = act; + return; + } + case ScaledActivationCase::kClampedSwiGLU: { + const bool dlinear_mask = linear_in <= kClampedLimit && linear_in >= -kClampedLimit; + const float act = qgelu_alpha(fminf(kClampedLimit, act_in), kClampedAlpha); + const float dact_base = + act_in <= kClampedLimit ? dqgelu_alpha(fminf(kClampedLimit, act_in), kClampedAlpha) + : 0.0f; + const float linear = + fminf(fmaxf(-kClampedLimit, linear_in), kClampedLimit) + kClampedLinearOffset; + *unscaled = act * linear; + *dact = dact_base * linear; + *dlinear = dlinear_mask ? act : 0.0f; + return; + } + case ScaledActivationCase::kSReLU: + *unscaled = srelu_ref(act_in); + *dact = dsrelu_ref(act_in); + *dlinear = 0.0f; + return; + } +} + +template +void compute_reference(ScaledActivationCase activation, const DataT *input, const ScaleT *scales, + const DataT *grad_output, DataT *output, DataT *grad_input, + DataT *grad_scales, const size_t rows, const size_t hidden, + const int64_t interleave, const bool compute_grad_scales) { + const bool is_gated = activation != ScaledActivationCase::kSReLU; + const size_t input_cols = is_gated ? hidden * 2 : hidden; + std::fill(grad_input, grad_input + rows * input_cols, static_cast(0.0f)); + + for (size_t row = 0; row < rows; ++row) { + const float scale = static_cast(scales[row]); + float scale_grad = 0.0f; + for (size_t col = 0; col < hidden; ++col) { + const size_t out_idx = row * hidden + col; + float unscaled = 0.0f; + float dact = 0.0f; + float dlinear = 0.0f; + if (is_gated) { + size_t act_idx = 0; + size_t linear_idx = 0; + glu_indices(row, col, hidden, interleave, &act_idx, &linear_idx); + const float act_in = static_cast(input[act_idx]); + const float linear_in = static_cast(input[linear_idx]); + unscaled = gated_unscaled(activation, act_in, linear_in); + gated_grads(activation, act_in, linear_in, &dact, &dlinear, &unscaled); + + const float scaled_grad = static_cast(grad_output[out_idx]) * scale; + grad_input[act_idx] = static_cast(scaled_grad * dact); + grad_input[linear_idx] = static_cast(scaled_grad * dlinear); + } else { + const float x = static_cast(input[out_idx]); + unscaled = srelu_ref(x); + const float scaled_grad = static_cast(grad_output[out_idx]) * scale; + grad_input[out_idx] = static_cast(scaled_grad * dsrelu_ref(x)); + } + + output[out_idx] = static_cast(unscaled * scale); + scale_grad += static_cast(grad_output[out_idx]) * unscaled; + } + if (compute_grad_scales) { + grad_scales[row] = static_cast(scale_grad); + } + } +} + +template +void run_scaled_activation_test(ScaledActivationCase activation, const size_t rows, + const size_t hidden, const int64_t interleave, + const bool compute_grad_scales) { + using namespace test; + const DType data_type = TypeInfo::dtype; + const DType scale_type = TypeInfo::dtype; + const bool is_gated = activation != ScaledActivationCase::kSReLU; + const size_t input_cols = is_gated ? hidden * 2 : hidden; + + Tensor input("input", std::vector{rows, input_cols}, data_type); + Tensor scales("act_scales", std::vector{rows}, scale_type); + Tensor output("output", std::vector{rows, hidden}, data_type); + Tensor grad_output("grad_output", std::vector{rows, hidden}, data_type); + Tensor grad_input("grad_input", std::vector{rows, input_cols}, data_type); + Tensor grad_scales("grad_scales", std::vector{rows}, data_type); + + fillUniform(&input); + fillUniform(&scales); + fillUniform(&grad_output); + + std::unique_ptr ref_output = std::make_unique(rows * hidden); + std::unique_ptr ref_grad_input = std::make_unique(rows * input_cols); + std::unique_ptr ref_grad_scales = std::make_unique(rows); + + compute_reference(activation, input.rowwise_cpu_dptr(), scales.rowwise_cpu_dptr(), + grad_output.rowwise_cpu_dptr(), ref_output.get(), + ref_grad_input.get(), ref_grad_scales.get(), rows, hidden, interleave, + compute_grad_scales); + + switch (activation) { + case ScaledActivationCase::kSwiGLU: + nvte_scaled_swiglu(input.data(), scales.data(), output.data(), interleave, 0); + nvte_scaled_dswiglu(grad_output.data(), input.data(), scales.data(), grad_input.data(), + compute_grad_scales ? grad_scales.data() : nullptr, interleave, 0); + break; + case ScaledActivationCase::kClampedSwiGLU: + nvte_scaled_clamped_swiglu(input.data(), scales.data(), output.data(), kClampedLimit, + kClampedAlpha, kClampedLinearOffset, interleave, 0); + nvte_scaled_clamped_dswiglu( + grad_output.data(), input.data(), scales.data(), grad_input.data(), + compute_grad_scales ? grad_scales.data() : nullptr, kClampedLimit, kClampedAlpha, + kClampedLinearOffset, interleave, 0); + break; + case ScaledActivationCase::kSReLU: + nvte_scaled_srelu(input.data(), scales.data(), output.data(), 0); + nvte_scaled_dsrelu(grad_output.data(), input.data(), scales.data(), grad_input.data(), + compute_grad_scales ? grad_scales.data() : nullptr, 0); + break; + } + + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + auto [atol, rtol] = getTolerances(data_type); + if (data_type == DType::kFloat32) { + atol = 5e-5; + rtol = 5e-5; + } + compareResults("scaled_activation_output", output, ref_output.get(), true, atol, rtol); + compareResults("scaled_activation_grad_input", grad_input, ref_grad_input.get(), true, atol, + rtol); + if (compute_grad_scales) { + compareResults("scaled_activation_grad_scales", grad_scales, ref_grad_scales.get(), true, atol, + rtol); + } +} + +class ScaledActivationTest + : public ::testing::TestWithParam< + std::tuple, int64_t, + bool>> { +}; + +std::string test_name_generator( + const testing::TestParamInfo &info) { + const auto activation = std::get<0>(info.param); + const auto data_type = std::get<1>(info.param); + const auto scale_type = std::get<2>(info.param); + const auto shape = std::get<3>(info.param); + const auto interleave = std::get<4>(info.param); + const auto compute_grad_scales = std::get<5>(info.param); + return std::string(activation_name(activation)) + "_data_" + test::typeName(data_type) + + "_scale_" + test::typeName(scale_type) + "_m_" + std::to_string(shape.first) + "_h_" + + std::to_string(shape.second) + "_interleave_" + std::to_string(interleave) + + (compute_grad_scales ? "_with_scale_grad" : "_no_scale_grad"); +} + +} // namespace + +TEST_P(ScaledActivationTest, ForwardBackward) { + const auto activation = std::get<0>(GetParam()); + const auto data_type = std::get<1>(GetParam()); + const auto scale_type = std::get<2>(GetParam()); + const auto shape = std::get<3>(GetParam()); + const auto interleave = std::get<4>(GetParam()); + const auto compute_grad_scales = std::get<5>(GetParam()); + + if (activation == ScaledActivationCase::kSReLU && interleave != 0) { + GTEST_SKIP() << "SReLU is not a GLU activation."; + } + if (activation != ScaledActivationCase::kSReLU && interleave > 0 && + shape.second % static_cast(interleave) != 0) { + GTEST_SKIP() << "Hidden size must be divisible by GLU interleave."; + } + + using namespace test; + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(data_type, DataT, { + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(scale_type, ScaleT, { + run_scaled_activation_test(activation, shape.first, shape.second, interleave, + compute_grad_scales); + }); + }); +} + +// Test axes (the six tuple elements consumed by ScaledActivationTest): +// 1. Activation : SwiGLU and ClampedSwiGLU are gated (input is [M, 2H]); +// SReLU is unary (input is [M, H], no gate split). +// 2. Data dtype : dtype of the activation input/output tensors. +// 3. Scale dtype : dtype of act_scales / grad_act_scales. +// 4. Shape {rows, hidden}: rows = M (tokens), hidden = H (output width; gated input is 2H). +// 5. GLU interleave : 0 = contiguous [a | b]; 32 = interleaved a/b blocks. Only valid +// for gated activations with hidden % 32 == 0; SReLU skips != 0. +// 6. compute_grad_scales : whether the backward also reduces grad_act_scales. + +// Interleave is swept over {0, 32}; invalid combinations -- SReLU with any nonzero interleave, or +// a gated activation whose hidden is not divisible by the interleave -- are skipped at runtime by +// the GTEST_SKIP guards in the test body. +INSTANTIATE_TEST_SUITE_P( + OperatorTest_ScaledActivation, ScaledActivationTest, + ::testing::Combine( + ::testing::Values(ScaledActivationCase::kSwiGLU, ScaledActivationCase::kClampedSwiGLU, + ScaledActivationCase::kSReLU), + ::testing::Values(DType::kFloat32, DType::kBFloat16), // data dtype + ::testing::Values(DType::kFloat32, DType::kBFloat16), // scale dtype + ::testing::Values(std::pair{17, 64}, // odd rows, aligned hidden + std::pair{32, 32}, // minimal aligned square + std::pair{128, 128}, // square + std::pair{256, 64}, // many rows, narrow hidden + std::pair{1024, 2048}, // large FFN-ish width + std::pair{1, 1}, // single element + std::pair{1, 96}, // single row + std::pair{96, 1}, // single hidden column + std::pair{13, 100}), // non-power-of-two + ::testing::Values(0, 32), // contiguous + interleaved + ::testing::Values(false, true)), // grad_act_scales off / on + test_name_generator); diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 8f96432ed8..b4ba17e048 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -255,6 +255,7 @@ list(APPEND transformer_engine_cuda_arch_specific_sources activation/relu_dbias.cu activation/relu_grouped.cu activation/relu_grouped_dbias.cu + activation/scaled_activation.cu activation/swiglu.cu activation/swiglu_dbias.cu activation/swiglu_grouped.cu @@ -513,6 +514,7 @@ if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH) activation/relu_dbias.cu activation/relu_grouped.cu activation/relu_grouped_dbias.cu + activation/scaled_activation.cu activation/swiglu.cu activation/swiglu_dbias.cu activation/swiglu_grouped.cu diff --git a/transformer_engine/common/activation/scaled_activation.cu b/transformer_engine/common/activation/scaled_activation.cu new file mode 100644 index 0000000000..73df92338c --- /dev/null +++ b/transformer_engine/common/activation/scaled_activation.cu @@ -0,0 +1,781 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/* Scaled activations: apply an activation, multiply by a per-row scale + * (act_scales[row]), do all math in fp32, and cast once at the store. The + * backward path optionally also reduces the gradient of the per-row scale. + * + * The six __global__ kernels below: + * + * # | Kernel | Activation | Dir | grad_act_scales | Launch + * ---+-----------------------------------------------+------------------------+-----+-----------------+-------------------- + * 1 | scaled_gated_forward_kernel | SwiGLU / ClampedSwiGLU | fwd | -- | vectorized row segments + * 2 | scaled_srelu_forward_kernel | SReLU (unary) | fwd | -- | vectorized row grid + * 3 | scaled_gated_backward_kernel | SwiGLU / ClampedSwiGLU | bwd | no | vectorized row segments + * 4 | scaled_srelu_backward_kernel | SReLU | bwd | no | vectorized row grid + * 5 | scaled_gated_backward_with_scale_grad_kernel | SwiGLU / ClampedSwiGLU | bwd | yes | vectorized, one block per row + * 6 | scaled_srelu_backward_with_scale_grad_kernel | SReLU | bwd | yes | vectorized, one block per row + * + * The "with scale grad" variants compute grad_act_scales[row] = sum_j dY * unscaled, + * a per-row reduction that requires the one-block-per-row launch; when + * grad_act_scales is null the cheaper flat element-wise grid is used instead. + * + * Vectorization model: + * + * Gated activations consume two FC1 streams per row: an activation stream and a + * gate stream. With no GLU interleave, the row is laid out as: + * + * [ act[0:H] | gate[0:H] ] + * + * With GLU interleave, e.g. interleave=32, the row is laid out as independent + * act/gate segments: + * + * [ act[0:32] | gate[0:32] | act[32:64] | gate[32:64] | ... ] + * + * Vector loads: + * + * interleave=0: + * input [ act0 | act1 | ... | actN | gate0 | gate1 | ... | gateN ] + * | | + * v v + * load act vector i gate vector i + * store output vector i = activation(act vector i) * gate vector i * scale[row] + * + * interleave=32: + * input [ act0 | gate0 | act1 | gate1 | ... | actN | gateN ] + * | | | | + * v v v v + * load act0 gate0 act1 gate1 + * store output vector i = activation(act vector i) * gate vector i * scale[row] + * + * Only fully aligned segments use vector loads. Everything else uses the same + * kernels with nvec=1, i.e. regular elementwise loads/stores. + */ + +#include + +#include + +#include "../common.h" +#include "../util/math.h" +#include "../util/vectorized_pointwise.h" + +namespace transformer_engine { +namespace { + +enum class ScaledActivation { + kSwiGLU, + kClampedSwiGLU, + kSReLU, +}; + +__device__ __forceinline__ float sigmoid_from_float(const float x) { + return 1.0f / (1.0f + expf(-x)); +} + +template +__device__ __forceinline__ float gated_forward_value(const float act_in, const float gate_in, + const ClampedSwiGLUParam ¶m) { + if constexpr (Act == ScaledActivation::kSwiGLU) { + Empty empty = {}; + return silu(act_in, empty) * gate_in; + } else { + const float gate = fminf(fmaxf(-param.limit, gate_in), param.limit) + param.glu_linear_offset; + return clamped_silu(act_in, param) * gate; + } +} + +template +__device__ __forceinline__ void gated_backward_values(const float act_in, const float gate_in, + const ClampedSwiGLUParam ¶m, float *dact, + float *dgate, float *unscaled) { + if constexpr (Act == ScaledActivation::kSwiGLU) { + const float sigmoid = sigmoid_from_float(act_in); + const float act = act_in * sigmoid; + const float dact_base = sigmoid + act_in * sigmoid * (1.0f - sigmoid); + *unscaled = act * gate_in; + *dact = dact_base * gate_in; + *dgate = act; + } else { + const bool dgate_mask = gate_in <= param.limit && gate_in >= -param.limit; + const float gate = fminf(fmaxf(-param.limit, gate_in), param.limit) + param.glu_linear_offset; + const bool dact_mask = act_in <= param.limit; + const float clamped_act_in = fminf(act_in, param.limit); + const float sigmoid = sigmoid_from_float(param.alpha * clamped_act_in); + const float act = clamped_act_in * sigmoid; + const float dact_base = + dact_mask ? sigmoid + param.alpha * clamped_act_in * sigmoid * (1.0f - sigmoid) : 0.0f; + *unscaled = act * gate; + *dact = dact_base * gate; + *dgate = dgate_mask ? act : 0.0f; + } +} + +constexpr int kThreads = unary_kernel_threads; +constexpr int kReductionThreads = 256; +constexpr int kReductionWarps = kReductionThreads / THREADS_PER_WARP; + +__device__ __forceinline__ float warp_reduce_sum(float value) { +#pragma unroll + for (int offset = THREADS_PER_WARP / 2; offset > 0; offset >>= 1) { + value += __shfl_down_sync(0xffffffff, value, offset); + } + return value; +} + +__device__ __forceinline__ float block_reduce_sum(float value, float *smem) { + const int lane = threadIdx.x % THREADS_PER_WARP; + const int warp = threadIdx.x / THREADS_PER_WARP; + + value = warp_reduce_sum(value); + if (lane == 0) { + smem[warp] = value; + } + __syncthreads(); + + value = threadIdx.x < kReductionWarps ? smem[lane] : 0.0f; + if (warp == 0) { + value = warp_reduce_sum(value); + } + return value; +} + +template +constexpr int vector_width() { + return 32 / static_cast(sizeof(T)); +} + +inline int launch_blocks(const size_t work_items) { + return static_cast( + std::min(DIVUP(work_items, static_cast(kThreads)), 65535)); +} + +template +Alignment row_vector_alignment(const size_t lead_dim, const int nvec, const Ptrs... ptrs) { + if (nvec == 1) { + return Alignment::SAME_ALIGNED; + } + // GLU interleave is handled as independent row-local segments. Keep the scalar + // fallback for odd segment widths or unaligned pointers so vector stores never + // cross from an activation segment into its paired gate segment. + if (lead_dim % static_cast(nvec) != 0) { + return Alignment::DIFFERENT; + } + const auto align = CheckAlignment(lead_dim, nvec, ptrs...); + return align == Alignment::SAME_ALIGNED ? Alignment::SAME_ALIGNED : Alignment::DIFFERENT; +} + +template +__global__ void scaled_gated_forward_kernel(const InputT *input, const ScaleT *act_scales, + OutputT *output, const size_t rows, const size_t hidden, + const size_t segment_size, const size_t num_segments, + const size_t num_vectors_per_segment, + const ClampedSwiGLUParam param) { + const size_t total_vectors = rows * num_segments * num_vectors_per_segment; + for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < total_vectors; + tid += gridDim.x * blockDim.x) { + const size_t vector_idx = tid % num_vectors_per_segment; + const size_t segment = (tid / num_vectors_per_segment) % num_segments; + const size_t row = tid / (num_vectors_per_segment * num_segments); + const size_t input_segment_offset = row * hidden * 2 + segment * segment_size * 2; + const size_t output_segment_offset = row * hidden + segment * segment_size; + + VectorizedLoader act_loader(input + input_segment_offset, segment_size); + VectorizedLoader gate_loader(input + input_segment_offset + segment_size, + segment_size); + VectorizedStorer output_storer(output + output_segment_offset, + segment_size); + act_loader.load(vector_idx, segment_size); + gate_loader.load(vector_idx, segment_size); + const float scale = static_cast(act_scales[row]); +#pragma unroll + for (int lane = 0; lane < nvec; ++lane) { + const float unscaled = + gated_forward_value(static_cast(act_loader.separate()[lane]), + static_cast(gate_loader.separate()[lane]), param); + output_storer.separate()[lane] = static_cast(unscaled * scale); + } + output_storer.store(vector_idx, segment_size); + } +} + +template +__global__ void scaled_srelu_forward_kernel(const InputT *input, const ScaleT *act_scales, + OutputT *output, const size_t rows, const size_t hidden, + const size_t num_vectors_per_row) { + Empty empty = {}; + const size_t total_vectors = rows * num_vectors_per_row; + for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < total_vectors; + tid += gridDim.x * blockDim.x) { + const size_t vector_idx = tid % num_vectors_per_row; + const size_t row = tid / num_vectors_per_row; + VectorizedLoader input_loader(input + row * hidden, hidden); + VectorizedStorer output_storer(output + row * hidden, hidden); + input_loader.load(vector_idx, hidden); + const float scale = static_cast(act_scales[row]); +#pragma unroll + for (int lane = 0; lane < nvec; ++lane) { + const float unscaled = + srelu(static_cast(input_loader.separate()[lane]), empty); + output_storer.separate()[lane] = static_cast(unscaled * scale); + } + output_storer.store(vector_idx, hidden); + } +} + +template +__global__ void scaled_gated_backward_kernel(const GradT *grad_output, const InputT *input, + const ScaleT *act_scales, OutputT *grad_input, + const size_t rows, const size_t hidden, + const size_t segment_size, const size_t num_segments, + const size_t num_vectors_per_segment, + const ClampedSwiGLUParam param) { + const size_t total_vectors = rows * num_segments * num_vectors_per_segment; + for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < total_vectors; + tid += gridDim.x * blockDim.x) { + const size_t vector_idx = tid % num_vectors_per_segment; + const size_t segment = (tid / num_vectors_per_segment) % num_segments; + const size_t row = tid / (num_vectors_per_segment * num_segments); + const size_t input_segment_offset = row * hidden * 2 + segment * segment_size * 2; + const size_t output_segment_offset = row * hidden + segment * segment_size; + + VectorizedLoader grad_loader(grad_output + output_segment_offset, + segment_size); + VectorizedLoader act_loader(input + input_segment_offset, segment_size); + VectorizedLoader gate_loader(input + input_segment_offset + segment_size, + segment_size); + VectorizedStorer act_storer(grad_input + input_segment_offset, + segment_size); + VectorizedStorer gate_storer( + grad_input + input_segment_offset + segment_size, segment_size); + grad_loader.load(vector_idx, segment_size); + act_loader.load(vector_idx, segment_size); + gate_loader.load(vector_idx, segment_size); + const float scale = static_cast(act_scales[row]); +#pragma unroll + for (int lane = 0; lane < nvec; ++lane) { + float dact = 0.0f; + float dgate = 0.0f; + float unscaled = 0.0f; + gated_backward_values(static_cast(act_loader.separate()[lane]), + static_cast(gate_loader.separate()[lane]), param, &dact, + &dgate, &unscaled); + (void)unscaled; + const float grad = static_cast(grad_loader.separate()[lane]) * scale; + act_storer.separate()[lane] = static_cast(grad * dact); + gate_storer.separate()[lane] = static_cast(grad * dgate); + } + act_storer.store(vector_idx, segment_size); + gate_storer.store(vector_idx, segment_size); + } +} + +template +__global__ void scaled_srelu_backward_kernel(const GradT *grad_output, const InputT *input, + const ScaleT *act_scales, OutputT *grad_input, + const size_t rows, const size_t hidden, + const size_t num_vectors_per_row) { + Empty empty = {}; + const size_t total_vectors = rows * num_vectors_per_row; + for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < total_vectors; + tid += gridDim.x * blockDim.x) { + const size_t vector_idx = tid % num_vectors_per_row; + const size_t row = tid / num_vectors_per_row; + VectorizedLoader grad_loader(grad_output + row * hidden, hidden); + VectorizedLoader input_loader(input + row * hidden, hidden); + VectorizedStorer grad_input_storer(grad_input + row * hidden, hidden); + grad_loader.load(vector_idx, hidden); + input_loader.load(vector_idx, hidden); + const float scale = static_cast(act_scales[row]); +#pragma unroll + for (int lane = 0; lane < nvec; ++lane) { + const float grad = static_cast(grad_loader.separate()[lane]) * scale; + grad_input_storer.separate()[lane] = static_cast( + grad * dsrelu(static_cast(input_loader.separate()[lane]), empty)); + } + grad_input_storer.store(vector_idx, hidden); + } +} + +template +__global__ void scaled_gated_backward_with_scale_grad_kernel( + const GradT *grad_output, const InputT *input, const ScaleT *act_scales, OutputT *grad_input, + GradScaleT *grad_act_scales, const size_t rows, const size_t hidden, const size_t segment_size, + const size_t num_segments, const size_t num_vectors_per_segment, + const ClampedSwiGLUParam param) { + __shared__ float smem[kReductionWarps]; + const size_t row = blockIdx.x; + (void)rows; + float scale_grad = 0.0f; + const float scale = static_cast(act_scales[row]); + + // Flatten (segment, vector) so interleave=32 distributes all row work across + // the block instead of using only a few threads per small act/gate segment. + const size_t row_vectors = num_segments * num_vectors_per_segment; + for (size_t row_vector_idx = threadIdx.x; row_vector_idx < row_vectors; + row_vector_idx += blockDim.x) { + const size_t segment = row_vector_idx / num_vectors_per_segment; + const size_t vector_idx = row_vector_idx % num_vectors_per_segment; + const size_t input_segment_offset = row * hidden * 2 + segment * segment_size * 2; + const size_t output_segment_offset = row * hidden + segment * segment_size; + VectorizedLoader grad_loader(grad_output + output_segment_offset, + segment_size); + VectorizedLoader act_loader(input + input_segment_offset, segment_size); + VectorizedLoader gate_loader(input + input_segment_offset + segment_size, + segment_size); + VectorizedStorer act_storer(grad_input + input_segment_offset, + segment_size); + VectorizedStorer gate_storer( + grad_input + input_segment_offset + segment_size, segment_size); + + grad_loader.load(vector_idx, segment_size); + act_loader.load(vector_idx, segment_size); + gate_loader.load(vector_idx, segment_size); +#pragma unroll + for (int lane = 0; lane < nvec; ++lane) { + float dact = 0.0f; + float dgate = 0.0f; + float unscaled = 0.0f; + gated_backward_values(static_cast(act_loader.separate()[lane]), + static_cast(gate_loader.separate()[lane]), param, &dact, + &dgate, &unscaled); + const float grad = static_cast(grad_loader.separate()[lane]); + scale_grad += grad * unscaled; + + const float scaled_grad = grad * scale; + act_storer.separate()[lane] = static_cast(scaled_grad * dact); + gate_storer.separate()[lane] = static_cast(scaled_grad * dgate); + } + act_storer.store(vector_idx, segment_size); + gate_storer.store(vector_idx, segment_size); + } + + scale_grad = block_reduce_sum(scale_grad, smem); + if (threadIdx.x == 0) { + grad_act_scales[row] = static_cast(scale_grad); + } +} + +template +__global__ void scaled_srelu_backward_with_scale_grad_kernel( + const GradT *grad_output, const InputT *input, const ScaleT *act_scales, OutputT *grad_input, + GradScaleT *grad_act_scales, const size_t rows, const size_t hidden, + const size_t num_vectors_per_row) { + __shared__ float smem[kReductionWarps]; + const size_t row = blockIdx.x; + (void)rows; + float scale_grad = 0.0f; + Empty empty = {}; + const float scale = static_cast(act_scales[row]); + + VectorizedLoader grad_loader(grad_output + row * hidden, hidden); + VectorizedLoader input_loader(input + row * hidden, hidden); + VectorizedStorer grad_input_storer(grad_input + row * hidden, hidden); + for (size_t vector_idx = threadIdx.x; vector_idx < num_vectors_per_row; + vector_idx += blockDim.x) { + grad_loader.load(vector_idx, hidden); + input_loader.load(vector_idx, hidden); +#pragma unroll + for (int lane = 0; lane < nvec; ++lane) { + const float unscaled = + srelu(static_cast(input_loader.separate()[lane]), empty); + const float grad = static_cast(grad_loader.separate()[lane]); + scale_grad += grad * unscaled; + + const float scaled_grad = grad * scale; + const float dact = + dsrelu(static_cast(input_loader.separate()[lane]), empty); + grad_input_storer.separate()[lane] = static_cast(scaled_grad * dact); + } + grad_input_storer.store(vector_idx, hidden); + } + + scale_grad = block_reduce_sum(scale_grad, smem); + if (threadIdx.x == 0) { + grad_act_scales[row] = static_cast(scale_grad); + } +} + +void check_scale_tensor(const Tensor *act_scales, const size_t rows, const char *api_name) { + NVTE_CHECK(act_scales->numel() == rows, api_name, ": act_scales must have one value per row."); +} + +void check_gated_forward_tensors(const Tensor *input, const Tensor *act_scales, + const Tensor *output, const int64_t glu_interleave_size, + const char *api_name, size_t *rows, size_t *hidden) { + const auto input_dims = input->flat_2d_dims(); + const auto output_dims = output->flat_2d_dims(); + NVTE_CHECK(input_dims[0] == output_dims[0], api_name, ": input/output row mismatch."); + NVTE_CHECK(input_dims[1] == output_dims[1] * 2, api_name, + ": gated input last dimension must be twice output last dimension."); + NVTE_CHECK(glu_interleave_size >= 0, api_name, ": glu_interleave_size must be non-negative."); + if (glu_interleave_size > 0) { + NVTE_CHECK(glu_interleave_size % 32 == 0, api_name, + ": nonzero glu_interleave_size must be a multiple of 32."); + NVTE_CHECK(output_dims[1] % static_cast(glu_interleave_size) == 0, api_name, + ": output last dimension must be divisible by glu_interleave_size."); + } + check_scale_tensor(act_scales, input_dims[0], api_name); + *rows = input_dims[0]; + *hidden = output_dims[1]; +} + +void check_unary_forward_tensors(const Tensor *input, const Tensor *act_scales, + const Tensor *output, const char *api_name, size_t *rows, + size_t *hidden) { + const auto input_dims = input->flat_2d_dims(); + const auto output_dims = output->flat_2d_dims(); + NVTE_CHECK(input_dims[0] == output_dims[0] && input_dims[1] == output_dims[1], api_name, + ": input/output shapes must match."); + check_scale_tensor(act_scales, input_dims[0], api_name); + *rows = input_dims[0]; + *hidden = output_dims[1]; +} + +void check_grad_scale_tensor(const Tensor *grad_act_scales, const size_t rows, + const char *api_name) { + if (grad_act_scales != nullptr) { + NVTE_CHECK(grad_act_scales->numel() == rows, api_name, + ": grad_act_scales must have one value per row."); + } +} + +void check_gated_backward_tensors(const Tensor *grad_output, const Tensor *input, + const Tensor *act_scales, const Tensor *grad_input, + const Tensor *grad_act_scales, const int64_t glu_interleave_size, + const char *api_name, size_t *rows, size_t *hidden) { + const auto grad_dims = grad_output->flat_2d_dims(); + const auto input_dims = input->flat_2d_dims(); + const auto grad_input_dims = grad_input->flat_2d_dims(); + NVTE_CHECK(grad_dims[0] == input_dims[0] && input_dims[0] == grad_input_dims[0], api_name, + ": input/grad row mismatch."); + NVTE_CHECK(input_dims[1] == grad_dims[1] * 2 && grad_input_dims[1] == input_dims[1], api_name, + ": gated backward dimensions are inconsistent."); + NVTE_CHECK(glu_interleave_size >= 0, api_name, ": glu_interleave_size must be non-negative."); + if (glu_interleave_size > 0) { + NVTE_CHECK(glu_interleave_size % 32 == 0, api_name, + ": nonzero glu_interleave_size must be a multiple of 32."); + NVTE_CHECK(grad_dims[1] % static_cast(glu_interleave_size) == 0, api_name, + ": grad last dimension must be divisible by glu_interleave_size."); + } + check_scale_tensor(act_scales, input_dims[0], api_name); + check_grad_scale_tensor(grad_act_scales, input_dims[0], api_name); + *rows = input_dims[0]; + *hidden = grad_dims[1]; +} + +void check_unary_backward_tensors(const Tensor *grad_output, const Tensor *input, + const Tensor *act_scales, const Tensor *grad_input, + const Tensor *grad_act_scales, const char *api_name, size_t *rows, + size_t *hidden) { + const auto grad_dims = grad_output->flat_2d_dims(); + const auto input_dims = input->flat_2d_dims(); + const auto grad_input_dims = grad_input->flat_2d_dims(); + NVTE_CHECK(grad_dims[0] == input_dims[0] && input_dims[0] == grad_input_dims[0], api_name, + ": input/grad row mismatch."); + NVTE_CHECK(grad_dims[1] == input_dims[1] && input_dims[1] == grad_input_dims[1], api_name, + ": unary backward dimensions are inconsistent."); + check_scale_tensor(act_scales, input_dims[0], api_name); + check_grad_scale_tensor(grad_act_scales, input_dims[0], api_name); + *rows = input_dims[0]; + *hidden = grad_dims[1]; +} + +template +void launch_scaled_gated_forward(const NVTETensor nvte_input, const NVTETensor nvte_act_scales, + NVTETensor nvte_output, const int64_t glu_interleave_size, + const ClampedSwiGLUParam param, cudaStream_t stream, + const char *api_name) { + const Tensor *input = convertNVTETensorCheck(nvte_input); + const Tensor *act_scales = convertNVTETensorCheck(nvte_act_scales); + Tensor *output = convertNVTETensorCheck(nvte_output); + size_t rows = 0; + size_t hidden = 0; + check_gated_forward_tensors(input, act_scales, output, glu_interleave_size, api_name, &rows, + &hidden); + if (rows == 0 || hidden == 0) return; + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(input->data.dtype, InputT, { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(act_scales->data.dtype, ScaleT, { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(output->data.dtype, OutputT, { + constexpr int nvec = sizeof(InputT) == sizeof(OutputT) ? vector_width() : 1; + const auto input_ptr = reinterpret_cast(input->data.dptr); + const auto scale_ptr = reinterpret_cast(act_scales->data.dptr); + auto output_ptr = reinterpret_cast(output->data.dptr); + const size_t segment_size = + glu_interleave_size > 0 ? static_cast(glu_interleave_size) : hidden; + const size_t num_segments = glu_interleave_size > 0 ? hidden / segment_size : 1; + const auto align = row_vector_alignment(segment_size, nvec, input_ptr, + input_ptr + segment_size, output_ptr); + const bool use_vector = align == Alignment::SAME_ALIGNED; + const size_t num_vectors = + use_vector ? get_num_aligned_elements(input_ptr, segment_size, nvec, sizeof(InputT)) + : segment_size; + const int blocks = launch_blocks(rows * num_segments * num_vectors); + if (use_vector) { + scaled_gated_forward_kernel + <<>>(input_ptr, scale_ptr, output_ptr, rows, hidden, + segment_size, num_segments, num_vectors, param); + } else { + scaled_gated_forward_kernel<1, InputT, ScaleT, OutputT, Act> + <<>>(input_ptr, scale_ptr, output_ptr, rows, hidden, + segment_size, num_segments, segment_size, param); + } + }); + }); + }); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +void launch_scaled_srelu_forward(const NVTETensor nvte_input, const NVTETensor nvte_act_scales, + NVTETensor nvte_output, cudaStream_t stream, + const char *api_name) { + const Tensor *input = convertNVTETensorCheck(nvte_input); + const Tensor *act_scales = convertNVTETensorCheck(nvte_act_scales); + Tensor *output = convertNVTETensorCheck(nvte_output); + size_t rows = 0; + size_t hidden = 0; + check_unary_forward_tensors(input, act_scales, output, api_name, &rows, &hidden); + if (rows == 0 || hidden == 0) return; + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(input->data.dtype, InputT, { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(act_scales->data.dtype, ScaleT, { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(output->data.dtype, OutputT, { + constexpr int nvec = sizeof(InputT) == sizeof(OutputT) ? vector_width() : 1; + const auto input_ptr = reinterpret_cast(input->data.dptr); + const auto scale_ptr = reinterpret_cast(act_scales->data.dptr); + auto output_ptr = reinterpret_cast(output->data.dptr); + const auto align = row_vector_alignment(hidden, nvec, input_ptr, output_ptr); + const bool use_vector = align == Alignment::SAME_ALIGNED; + const size_t num_vectors = + use_vector ? get_num_aligned_elements(input_ptr, hidden, nvec, sizeof(InputT)) : hidden; + const int blocks = launch_blocks(rows * num_vectors); + if (use_vector) { + scaled_srelu_forward_kernel + <<>>(input_ptr, scale_ptr, output_ptr, rows, hidden, + num_vectors); + } else { + scaled_srelu_forward_kernel<1, InputT, ScaleT, OutputT><<>>( + input_ptr, scale_ptr, output_ptr, rows, hidden, hidden); + } + }); + }); + }); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +template +void launch_scaled_gated_backward(const NVTETensor nvte_grad_output, const NVTETensor nvte_input, + const NVTETensor nvte_act_scales, NVTETensor nvte_grad_input, + NVTETensor nvte_grad_act_scales, + const int64_t glu_interleave_size, const ClampedSwiGLUParam param, + cudaStream_t stream, const char *api_name) { + const Tensor *grad_output = convertNVTETensorCheck(nvte_grad_output); + const Tensor *input = convertNVTETensorCheck(nvte_input); + const Tensor *act_scales = convertNVTETensorCheck(nvte_act_scales); + Tensor *grad_input = convertNVTETensorCheck(nvte_grad_input); + Tensor *grad_act_scales = + nvte_grad_act_scales == nullptr ? nullptr : convertNVTETensorCheck(nvte_grad_act_scales); + size_t rows = 0; + size_t hidden = 0; + check_gated_backward_tensors(grad_output, input, act_scales, grad_input, grad_act_scales, + glu_interleave_size, api_name, &rows, &hidden); + if (rows == 0 || hidden == 0) return; + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(grad_output->data.dtype, GradT, { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(input->data.dtype, InputT, { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(act_scales->data.dtype, ScaleT, { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(grad_input->data.dtype, OutputT, { + constexpr int nvec = sizeof(GradT) == sizeof(InputT) && sizeof(InputT) == sizeof(OutputT) + ? vector_width() + : 1; + const auto grad_ptr = reinterpret_cast(grad_output->data.dptr); + const auto input_ptr = reinterpret_cast(input->data.dptr); + const auto scale_ptr = reinterpret_cast(act_scales->data.dptr); + auto grad_input_ptr = reinterpret_cast(grad_input->data.dptr); + const size_t segment_size = + glu_interleave_size > 0 ? static_cast(glu_interleave_size) : hidden; + const size_t num_segments = glu_interleave_size > 0 ? hidden / segment_size : 1; + const auto align = row_vector_alignment(segment_size, nvec, grad_ptr, input_ptr, + input_ptr + segment_size, grad_input_ptr, + grad_input_ptr + segment_size); + const bool use_vector = align == Alignment::SAME_ALIGNED; + const size_t num_vectors = + use_vector ? get_num_aligned_elements(input_ptr, segment_size, nvec, sizeof(InputT)) + : segment_size; + if (grad_act_scales == nullptr) { + const int blocks = launch_blocks(rows * num_segments * num_vectors); + if (use_vector) { + scaled_gated_backward_kernel + <<>>(grad_ptr, input_ptr, scale_ptr, grad_input_ptr, + rows, hidden, segment_size, num_segments, + num_vectors, param); + } else { + scaled_gated_backward_kernel<1, GradT, InputT, ScaleT, OutputT, Act> + <<>>(grad_ptr, input_ptr, scale_ptr, grad_input_ptr, + rows, hidden, segment_size, num_segments, + segment_size, param); + } + } else { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(grad_act_scales->data.dtype, GradScaleT, { + auto grad_act_scales_ptr = reinterpret_cast(grad_act_scales->data.dptr); + if (use_vector) { + scaled_gated_backward_with_scale_grad_kernel + <<(rows), kReductionThreads, 0, stream>>>( + grad_ptr, input_ptr, scale_ptr, grad_input_ptr, grad_act_scales_ptr, rows, + hidden, segment_size, num_segments, num_vectors, param); + } else { + scaled_gated_backward_with_scale_grad_kernel<1, GradT, InputT, ScaleT, OutputT, + GradScaleT, Act> + <<(rows), kReductionThreads, 0, stream>>>( + grad_ptr, input_ptr, scale_ptr, grad_input_ptr, grad_act_scales_ptr, rows, + hidden, segment_size, num_segments, segment_size, param); + } + }); + } + }); + }); + }); + }); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +void launch_scaled_srelu_backward(const NVTETensor nvte_grad_output, const NVTETensor nvte_input, + const NVTETensor nvte_act_scales, NVTETensor nvte_grad_input, + NVTETensor nvte_grad_act_scales, cudaStream_t stream, + const char *api_name) { + const Tensor *grad_output = convertNVTETensorCheck(nvte_grad_output); + const Tensor *input = convertNVTETensorCheck(nvte_input); + const Tensor *act_scales = convertNVTETensorCheck(nvte_act_scales); + Tensor *grad_input = convertNVTETensorCheck(nvte_grad_input); + Tensor *grad_act_scales = + nvte_grad_act_scales == nullptr ? nullptr : convertNVTETensorCheck(nvte_grad_act_scales); + size_t rows = 0; + size_t hidden = 0; + check_unary_backward_tensors(grad_output, input, act_scales, grad_input, grad_act_scales, + api_name, &rows, &hidden); + if (rows == 0 || hidden == 0) return; + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(grad_output->data.dtype, GradT, { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(input->data.dtype, InputT, { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(act_scales->data.dtype, ScaleT, { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(grad_input->data.dtype, OutputT, { + constexpr int nvec = sizeof(GradT) == sizeof(InputT) && sizeof(InputT) == sizeof(OutputT) + ? vector_width() + : 1; + const auto grad_ptr = reinterpret_cast(grad_output->data.dptr); + const auto input_ptr = reinterpret_cast(input->data.dptr); + const auto scale_ptr = reinterpret_cast(act_scales->data.dptr); + auto grad_input_ptr = reinterpret_cast(grad_input->data.dptr); + const auto align = + row_vector_alignment(hidden, nvec, grad_ptr, input_ptr, grad_input_ptr); + const bool use_vector = align == Alignment::SAME_ALIGNED; + const size_t num_vectors = + use_vector ? get_num_aligned_elements(input_ptr, hidden, nvec, sizeof(InputT)) + : hidden; + if (grad_act_scales == nullptr) { + const int blocks = launch_blocks(rows * num_vectors); + if (use_vector) { + scaled_srelu_backward_kernel + <<>>(grad_ptr, input_ptr, scale_ptr, grad_input_ptr, + rows, hidden, num_vectors); + } else { + scaled_srelu_backward_kernel<1, GradT, InputT, ScaleT, OutputT> + <<>>(grad_ptr, input_ptr, scale_ptr, grad_input_ptr, + rows, hidden, hidden); + } + } else { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(grad_act_scales->data.dtype, GradScaleT, { + auto grad_act_scales_ptr = reinterpret_cast(grad_act_scales->data.dptr); + if (use_vector) { + scaled_srelu_backward_with_scale_grad_kernel + <<(rows), kReductionThreads, 0, stream>>>( + grad_ptr, input_ptr, scale_ptr, grad_input_ptr, grad_act_scales_ptr, rows, + hidden, num_vectors); + } else { + scaled_srelu_backward_with_scale_grad_kernel<1, GradT, InputT, ScaleT, OutputT, + GradScaleT> + <<(rows), kReductionThreads, 0, stream>>>( + grad_ptr, input_ptr, scale_ptr, grad_input_ptr, grad_act_scales_ptr, rows, + hidden, hidden); + } + }); + } + }); + }); + }); + }); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +} // namespace +} // namespace transformer_engine + +void nvte_scaled_swiglu(const NVTETensor input, const NVTETensor act_scales, NVTETensor output, + int64_t glu_interleave_size, cudaStream_t stream) { + NVTE_API_CALL(nvte_scaled_swiglu); + using namespace transformer_engine; + Empty empty = {}; + (void)empty; + ClampedSwiGLUParam param = {}; + launch_scaled_gated_forward( + input, act_scales, output, glu_interleave_size, param, stream, "nvte_scaled_swiglu"); +} + +void nvte_scaled_dswiglu(const NVTETensor grad, const NVTETensor input, const NVTETensor act_scales, + NVTETensor grad_input, NVTETensor grad_act_scales, + int64_t glu_interleave_size, cudaStream_t stream) { + NVTE_API_CALL(nvte_scaled_dswiglu); + using namespace transformer_engine; + ClampedSwiGLUParam param = {}; + launch_scaled_gated_backward(grad, input, act_scales, grad_input, + grad_act_scales, glu_interleave_size, + param, stream, "nvte_scaled_dswiglu"); +} + +void nvte_scaled_clamped_swiglu(const NVTETensor input, const NVTETensor act_scales, + NVTETensor output, float limit, float alpha, + float glu_linear_offset, int64_t glu_interleave_size, + cudaStream_t stream) { + NVTE_API_CALL(nvte_scaled_clamped_swiglu); + using namespace transformer_engine; + ClampedSwiGLUParam param = {limit, alpha, glu_linear_offset}; + launch_scaled_gated_forward( + input, act_scales, output, glu_interleave_size, param, stream, "nvte_scaled_clamped_swiglu"); +} + +void nvte_scaled_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, + const NVTETensor act_scales, NVTETensor grad_input, + NVTETensor grad_act_scales, float limit, float alpha, + float glu_linear_offset, int64_t glu_interleave_size, + cudaStream_t stream) { + NVTE_API_CALL(nvte_scaled_clamped_dswiglu); + using namespace transformer_engine; + ClampedSwiGLUParam param = {limit, alpha, glu_linear_offset}; + launch_scaled_gated_backward( + grad, input, act_scales, grad_input, grad_act_scales, glu_interleave_size, param, stream, + "nvte_scaled_clamped_dswiglu"); +} + +void nvte_scaled_srelu(const NVTETensor input, const NVTETensor act_scales, NVTETensor output, + cudaStream_t stream) { + NVTE_API_CALL(nvte_scaled_srelu); + using namespace transformer_engine; + launch_scaled_srelu_forward(input, act_scales, output, stream, "nvte_scaled_srelu"); +} + +void nvte_scaled_dsrelu(const NVTETensor grad, const NVTETensor input, const NVTETensor act_scales, + NVTETensor grad_input, NVTETensor grad_act_scales, cudaStream_t stream) { + NVTE_API_CALL(nvte_scaled_dsrelu); + using namespace transformer_engine; + launch_scaled_srelu_backward(grad, input, act_scales, grad_input, grad_act_scales, stream, + "nvte_scaled_dsrelu"); +} diff --git a/transformer_engine/common/include/transformer_engine/activation.h b/transformer_engine/common/include/transformer_engine/activation.h index 4ed083740d..ed90428f8c 100644 --- a/transformer_engine/common/include/transformer_engine/activation.h +++ b/transformer_engine/common/include/transformer_engine/activation.h @@ -368,6 +368,41 @@ void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, void nvte_clamped_swiglu_v2(const NVTETensor input, NVTETensor output, float limit, float alpha, float glu_linear_offset, cudaStream_t stream); +/*! \brief Computes ScaledSwiGLU without materializing GLU deinterleave. + * + * Computes output = SwiGLU(input) * act_scales[:, None]. + * If glu_interleave_size > 0, input is interpreted as interleaved + * [activation_block, linear_block] chunks of that size. + * + * \param[in] input Input tensor of shape [N, H * 2]. + * \param[in] act_scales Row-wise activation scales of shape [N]. + * \param[in,out] output Output tensor of shape [N, H]. + * \param[in] glu_interleave_size GLU interleave chunk size, or 0 for non-interleaved layout. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_scaled_swiglu(const NVTETensor input, const NVTETensor act_scales, NVTETensor output, + int64_t glu_interleave_size, cudaStream_t stream); + +/*! \brief Computes ScaledClampedSwiGLU without materializing GLU deinterleave. + * + * Computes output = ClampedSwiGLU(input) * act_scales[:, None]. + * This uses the same clamping, alpha, and linear-offset semantics as + * nvte_clamped_swiglu_v2. + * + * \param[in] input Input tensor of shape [N, H * 2]. + * \param[in] act_scales Row-wise activation scales of shape [N]. + * \param[in,out] output Output tensor of shape [N, H]. + * \param[in] limit Clipping limit. + * \param[in] alpha Activation sigmoid alpha. + * \param[in] glu_linear_offset Offset added to linear component after clamping. + * \param[in] glu_interleave_size GLU interleave chunk size, or 0 for non-interleaved layout. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_scaled_clamped_swiglu(const NVTETensor input, const NVTETensor act_scales, + NVTETensor output, float limit, float alpha, + float glu_linear_offset, int64_t glu_interleave_size, + cudaStream_t stream); + /*! \brief Computes the gated ReLU activation of the input. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. @@ -473,6 +508,45 @@ void nvte_clamped_dswiglu_v2(const NVTETensor grad, const NVTETensor input, NVTE float limit, float alpha, float glu_linear_offset, cudaStream_t stream); +/*! \brief Computes ScaledSwiGLU backward without materializing GLU deinterleave. + * + * The optional grad_act_scales tensor may be null. When present, it receives + * sum(dY * SwiGLU(input), dim=-1). + * + * \param[in] grad Incoming gradient of shape [N, H]. + * \param[in] input Forward input tensor of shape [N, H * 2]. + * \param[in] act_scales Row-wise activation scales of shape [N]. + * \param[in,out] grad_input Outgoing gradient of shape [N, H * 2]. + * \param[in,out] grad_act_scales Optional row-wise scale gradient of shape [N], or null. + * \param[in] glu_interleave_size GLU interleave chunk size, or 0 for non-interleaved layout. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_scaled_dswiglu(const NVTETensor grad, const NVTETensor input, const NVTETensor act_scales, + NVTETensor grad_input, NVTETensor grad_act_scales, + int64_t glu_interleave_size, cudaStream_t stream); + +/*! \brief Computes ScaledClampedSwiGLU backward without materializing GLU deinterleave. + * + * The optional grad_act_scales tensor may be null. When present, it receives + * sum(dY * ClampedSwiGLU(input), dim=-1). + * + * \param[in] grad Incoming gradient of shape [N, H]. + * \param[in] input Forward input tensor of shape [N, H * 2]. + * \param[in] act_scales Row-wise activation scales of shape [N]. + * \param[in,out] grad_input Outgoing gradient of shape [N, H * 2]. + * \param[in,out] grad_act_scales Optional row-wise scale gradient of shape [N], or null. + * \param[in] limit Clipping limit. + * \param[in] alpha Activation sigmoid alpha. + * \param[in] glu_linear_offset Offset added to linear component after clamping. + * \param[in] glu_interleave_size GLU interleave chunk size, or 0 for non-interleaved layout. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_scaled_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, + const NVTETensor act_scales, NVTETensor grad_input, + NVTETensor grad_act_scales, float limit, float alpha, + float glu_linear_offset, int64_t glu_interleave_size, + cudaStream_t stream); + /*! \brief Computes the gated ReLU activation gradient. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. @@ -509,6 +583,33 @@ void nvte_dqgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor outp void nvte_dsreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes ScaledSReLU. + * + * Computes output = SReLU(input) * act_scales[:, None]. + * + * \param[in] input Input tensor for activation. + * \param[in] act_scales Row-wise activation scales of shape [N]. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_scaled_srelu(const NVTETensor input, const NVTETensor act_scales, NVTETensor output, + cudaStream_t stream); + +/*! \brief Computes ScaledSReLU backward. + * + * The optional grad_act_scales tensor may be null. When present, it receives + * sum(dY * SReLU(input), dim=-1). + * + * \param[in] grad Incoming gradient. + * \param[in] input Forward input tensor. + * \param[in] act_scales Row-wise activation scales of shape [N]. + * \param[in,out] grad_input Outgoing input gradient. + * \param[in,out] grad_act_scales Optional row-wise scale gradient of shape [N], or null. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_scaled_dsrelu(const NVTETensor grad, const NVTETensor input, const NVTETensor act_scales, + NVTETensor grad_input, NVTETensor grad_act_scales, cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif