-
Notifications
You must be signed in to change notification settings - Fork 752
[Common] Support scaled & clamped swiglu, srelu for BF16 #3132
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
zhongbozhu
wants to merge
6
commits into
NVIDIA:main
Choose a base branch
from
zhongbozhu:add_support_fused_swiglu
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+1,206
−0
Open
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
18d4d2c
support scaled swiglu, scaled srelu and scaled clamp swiglu
zhongbozhu 953c469
vectorized loading improvement
zhongbozhu e3ae293
fix bug for backward kernel
zhongbozhu 84cbdec
optimize
zhongbozhu c73c8ea
fix unit test failure
zhongbozhu 3eb18a6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,321 @@ | ||
| /************************************************************************* | ||
| * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| * | ||
| * See LICENSE for license information. | ||
| ************************************************************************/ | ||
|
|
||
| #include <algorithm> | ||
| #include <cmath> | ||
| #include <memory> | ||
| #include <string> | ||
| #include <tuple> | ||
|
|
||
| #include <cuda_runtime.h> | ||
| #include <gtest/gtest.h> | ||
|
|
||
| #include <transformer_engine/activation.h> | ||
|
|
||
| #include "../test_common.h" | ||
|
|
||
| using namespace transformer_engine; | ||
|
|
||
| namespace { | ||
|
|
||
| enum class ScaledActivationCase { | ||
| kSwiGLU, | ||
| kClampedSwiGLU, | ||
| kSReLU, | ||
| }; | ||
|
|
||
| constexpr float kClampedLimit = 1.3f; | ||
| constexpr float kClampedAlpha = 1.702f; | ||
| constexpr float kClampedLinearOffset = 0.5f; | ||
|
|
||
| const char *activation_name(ScaledActivationCase activation) { | ||
| switch (activation) { | ||
| case ScaledActivationCase::kSwiGLU: | ||
| return "scaled_swiglu"; | ||
| case ScaledActivationCase::kClampedSwiGLU: | ||
| return "scaled_clamped_swiglu"; | ||
| case ScaledActivationCase::kSReLU: | ||
| return "scaled_srelu"; | ||
| } | ||
| return "unknown"; | ||
| } | ||
|
|
||
| inline float sigmoid(const float x) { return 1.0f / (1.0f + expf(-x)); } | ||
|
|
||
| inline float qgelu_alpha(const float x, const float alpha) { return x * sigmoid(alpha * x); } | ||
|
|
||
| inline float dqgelu_alpha(const float x, const float alpha) { | ||
| const float sig = sigmoid(alpha * x); | ||
| return alpha * x * sig * (1.0f - sig) + sig; | ||
| } | ||
|
|
||
| inline float silu_ref(const float x) { return x * sigmoid(x); } | ||
|
|
||
| inline float dsilu_ref(const float x) { | ||
| const float sig = sigmoid(x); | ||
| return x * sig * (1.0f - sig) + sig; | ||
| } | ||
|
|
||
| inline float srelu_ref(const float x) { return x > 0.0f ? x * x : 0.0f; } | ||
|
|
||
| inline float dsrelu_ref(const float x) { return fmaxf(0.0f, 2.0f * x); } | ||
|
|
||
| inline void glu_indices(const size_t row, const size_t col, const size_t hidden, | ||
| const int64_t interleave, size_t *act_idx, size_t *linear_idx) { | ||
| if (interleave > 0) { | ||
| const size_t block = col / static_cast<size_t>(interleave); | ||
| const size_t lane = col % static_cast<size_t>(interleave); | ||
| const size_t base = row * hidden * 2 + block * static_cast<size_t>(interleave) * 2 + lane; | ||
| *act_idx = base; | ||
| *linear_idx = base + static_cast<size_t>(interleave); | ||
| } else { | ||
| const size_t base = row * hidden * 2; | ||
| *act_idx = base + col; | ||
| *linear_idx = base + hidden + col; | ||
| } | ||
| } | ||
|
|
||
| inline float gated_unscaled(const ScaledActivationCase activation, const float act_in, | ||
| const float linear_in) { | ||
| switch (activation) { | ||
| case ScaledActivationCase::kSwiGLU: | ||
| return silu_ref(act_in) * linear_in; | ||
| case ScaledActivationCase::kClampedSwiGLU: { | ||
| const float act = qgelu_alpha(fminf(kClampedLimit, act_in), kClampedAlpha); | ||
| const float linear = | ||
| fminf(fmaxf(-kClampedLimit, linear_in), kClampedLimit) + kClampedLinearOffset; | ||
| return act * linear; | ||
| } | ||
| case ScaledActivationCase::kSReLU: | ||
| return srelu_ref(act_in); | ||
| } | ||
| return 0.0f; | ||
| } | ||
|
|
||
| inline void gated_grads(const ScaledActivationCase activation, const float act_in, | ||
| const float linear_in, float *dact, float *dlinear, float *unscaled) { | ||
| switch (activation) { | ||
| case ScaledActivationCase::kSwiGLU: { | ||
| const float act = silu_ref(act_in); | ||
| *unscaled = act * linear_in; | ||
| *dact = dsilu_ref(act_in) * linear_in; | ||
| *dlinear = act; | ||
| return; | ||
| } | ||
| case ScaledActivationCase::kClampedSwiGLU: { | ||
| const bool dlinear_mask = linear_in <= kClampedLimit && linear_in >= -kClampedLimit; | ||
| const float act = qgelu_alpha(fminf(kClampedLimit, act_in), kClampedAlpha); | ||
| const float dact_base = | ||
| act_in <= kClampedLimit ? dqgelu_alpha(fminf(kClampedLimit, act_in), kClampedAlpha) | ||
| : 0.0f; | ||
| const float linear = | ||
| fminf(fmaxf(-kClampedLimit, linear_in), kClampedLimit) + kClampedLinearOffset; | ||
| *unscaled = act * linear; | ||
| *dact = dact_base * linear; | ||
| *dlinear = dlinear_mask ? act : 0.0f; | ||
| return; | ||
| } | ||
| case ScaledActivationCase::kSReLU: | ||
| *unscaled = srelu_ref(act_in); | ||
| *dact = dsrelu_ref(act_in); | ||
| *dlinear = 0.0f; | ||
| return; | ||
| } | ||
| } | ||
|
|
||
| template <typename DataT, typename ScaleT> | ||
| void compute_reference(ScaledActivationCase activation, const DataT *input, const ScaleT *scales, | ||
| const DataT *grad_output, DataT *output, DataT *grad_input, | ||
| DataT *grad_scales, const size_t rows, const size_t hidden, | ||
| const int64_t interleave, const bool compute_grad_scales) { | ||
| const bool is_gated = activation != ScaledActivationCase::kSReLU; | ||
| const size_t input_cols = is_gated ? hidden * 2 : hidden; | ||
| std::fill(grad_input, grad_input + rows * input_cols, static_cast<DataT>(0.0f)); | ||
|
|
||
| for (size_t row = 0; row < rows; ++row) { | ||
| const float scale = static_cast<float>(scales[row]); | ||
| float scale_grad = 0.0f; | ||
| for (size_t col = 0; col < hidden; ++col) { | ||
| const size_t out_idx = row * hidden + col; | ||
| float unscaled = 0.0f; | ||
| float dact = 0.0f; | ||
| float dlinear = 0.0f; | ||
| if (is_gated) { | ||
| size_t act_idx = 0; | ||
| size_t linear_idx = 0; | ||
| glu_indices(row, col, hidden, interleave, &act_idx, &linear_idx); | ||
| const float act_in = static_cast<float>(input[act_idx]); | ||
| const float linear_in = static_cast<float>(input[linear_idx]); | ||
| unscaled = gated_unscaled(activation, act_in, linear_in); | ||
| gated_grads(activation, act_in, linear_in, &dact, &dlinear, &unscaled); | ||
|
|
||
| const float scaled_grad = static_cast<float>(grad_output[out_idx]) * scale; | ||
| grad_input[act_idx] = static_cast<DataT>(scaled_grad * dact); | ||
| grad_input[linear_idx] = static_cast<DataT>(scaled_grad * dlinear); | ||
| } else { | ||
| const float x = static_cast<float>(input[out_idx]); | ||
| unscaled = srelu_ref(x); | ||
| const float scaled_grad = static_cast<float>(grad_output[out_idx]) * scale; | ||
| grad_input[out_idx] = static_cast<DataT>(scaled_grad * dsrelu_ref(x)); | ||
| } | ||
|
|
||
| output[out_idx] = static_cast<DataT>(unscaled * scale); | ||
| scale_grad += static_cast<float>(grad_output[out_idx]) * unscaled; | ||
| } | ||
| if (compute_grad_scales) { | ||
| grad_scales[row] = static_cast<DataT>(scale_grad); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| template <typename DataT, typename ScaleT> | ||
| void run_scaled_activation_test(ScaledActivationCase activation, const size_t rows, | ||
| const size_t hidden, const int64_t interleave, | ||
| const bool compute_grad_scales) { | ||
| using namespace test; | ||
| const DType data_type = TypeInfo<DataT>::dtype; | ||
| const DType scale_type = TypeInfo<ScaleT>::dtype; | ||
| const bool is_gated = activation != ScaledActivationCase::kSReLU; | ||
| const size_t input_cols = is_gated ? hidden * 2 : hidden; | ||
|
|
||
| Tensor input("input", std::vector<size_t>{rows, input_cols}, data_type); | ||
| Tensor scales("act_scales", std::vector<size_t>{rows}, scale_type); | ||
| Tensor output("output", std::vector<size_t>{rows, hidden}, data_type); | ||
| Tensor grad_output("grad_output", std::vector<size_t>{rows, hidden}, data_type); | ||
| Tensor grad_input("grad_input", std::vector<size_t>{rows, input_cols}, data_type); | ||
| Tensor grad_scales("grad_scales", std::vector<size_t>{rows}, data_type); | ||
|
|
||
| fillUniform(&input); | ||
| fillUniform(&scales); | ||
| fillUniform(&grad_output); | ||
|
|
||
| std::unique_ptr<DataT[]> ref_output = std::make_unique<DataT[]>(rows * hidden); | ||
| std::unique_ptr<DataT[]> ref_grad_input = std::make_unique<DataT[]>(rows * input_cols); | ||
| std::unique_ptr<DataT[]> ref_grad_scales = std::make_unique<DataT[]>(rows); | ||
|
|
||
| compute_reference(activation, input.rowwise_cpu_dptr<DataT>(), scales.rowwise_cpu_dptr<ScaleT>(), | ||
| grad_output.rowwise_cpu_dptr<DataT>(), ref_output.get(), | ||
| ref_grad_input.get(), ref_grad_scales.get(), rows, hidden, interleave, | ||
| compute_grad_scales); | ||
|
|
||
| switch (activation) { | ||
| case ScaledActivationCase::kSwiGLU: | ||
| nvte_scaled_swiglu(input.data(), scales.data(), output.data(), interleave, 0); | ||
| nvte_scaled_dswiglu(grad_output.data(), input.data(), scales.data(), grad_input.data(), | ||
| compute_grad_scales ? grad_scales.data() : nullptr, interleave, 0); | ||
| break; | ||
| case ScaledActivationCase::kClampedSwiGLU: | ||
| nvte_scaled_clamped_swiglu(input.data(), scales.data(), output.data(), kClampedLimit, | ||
| kClampedAlpha, kClampedLinearOffset, interleave, 0); | ||
| nvte_scaled_clamped_dswiglu( | ||
| grad_output.data(), input.data(), scales.data(), grad_input.data(), | ||
| compute_grad_scales ? grad_scales.data() : nullptr, kClampedLimit, kClampedAlpha, | ||
| kClampedLinearOffset, interleave, 0); | ||
| break; | ||
| case ScaledActivationCase::kSReLU: | ||
| nvte_scaled_srelu(input.data(), scales.data(), output.data(), 0); | ||
| nvte_scaled_dsrelu(grad_output.data(), input.data(), scales.data(), grad_input.data(), | ||
| compute_grad_scales ? grad_scales.data() : nullptr, 0); | ||
| break; | ||
| } | ||
|
|
||
| NVTE_CHECK_CUDA(cudaDeviceSynchronize()); | ||
| auto err = cudaGetLastError(); | ||
| ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); | ||
|
|
||
| auto [atol, rtol] = getTolerances(data_type); | ||
| if (data_type == DType::kFloat32) { | ||
| atol = 5e-5; | ||
| rtol = 5e-5; | ||
| } | ||
| compareResults("scaled_activation_output", output, ref_output.get(), true, atol, rtol); | ||
| compareResults("scaled_activation_grad_input", grad_input, ref_grad_input.get(), true, atol, | ||
| rtol); | ||
| if (compute_grad_scales) { | ||
| compareResults("scaled_activation_grad_scales", grad_scales, ref_grad_scales.get(), true, atol, | ||
| rtol); | ||
| } | ||
| } | ||
|
|
||
| class ScaledActivationTest | ||
| : public ::testing::TestWithParam< | ||
| std::tuple<ScaledActivationCase, DType, DType, std::pair<size_t, size_t>, int64_t, | ||
| bool>> { | ||
| }; | ||
|
|
||
| std::string test_name_generator( | ||
| const testing::TestParamInfo<ScaledActivationTest::ParamType> &info) { | ||
| const auto activation = std::get<0>(info.param); | ||
| const auto data_type = std::get<1>(info.param); | ||
| const auto scale_type = std::get<2>(info.param); | ||
| const auto shape = std::get<3>(info.param); | ||
| const auto interleave = std::get<4>(info.param); | ||
| const auto compute_grad_scales = std::get<5>(info.param); | ||
| return std::string(activation_name(activation)) + "_data_" + test::typeName(data_type) + | ||
| "_scale_" + test::typeName(scale_type) + "_m_" + std::to_string(shape.first) + "_h_" + | ||
| std::to_string(shape.second) + "_interleave_" + std::to_string(interleave) + | ||
| (compute_grad_scales ? "_with_scale_grad" : "_no_scale_grad"); | ||
| } | ||
|
|
||
| } // namespace | ||
|
|
||
| TEST_P(ScaledActivationTest, ForwardBackward) { | ||
| const auto activation = std::get<0>(GetParam()); | ||
| const auto data_type = std::get<1>(GetParam()); | ||
| const auto scale_type = std::get<2>(GetParam()); | ||
| const auto shape = std::get<3>(GetParam()); | ||
| const auto interleave = std::get<4>(GetParam()); | ||
| const auto compute_grad_scales = std::get<5>(GetParam()); | ||
|
|
||
| if (activation == ScaledActivationCase::kSReLU && interleave != 0) { | ||
| GTEST_SKIP() << "SReLU is not a GLU activation."; | ||
| } | ||
| if (activation != ScaledActivationCase::kSReLU && interleave > 0 && | ||
| shape.second % static_cast<size_t>(interleave) != 0) { | ||
| GTEST_SKIP() << "Hidden size must be divisible by GLU interleave."; | ||
| } | ||
|
|
||
| using namespace test; | ||
| TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(data_type, DataT, { | ||
| TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(scale_type, ScaleT, { | ||
| run_scaled_activation_test<DataT, ScaleT>(activation, shape.first, shape.second, interleave, | ||
| compute_grad_scales); | ||
| }); | ||
| }); | ||
| } | ||
|
|
||
| // Test axes (the six tuple elements consumed by ScaledActivationTest): | ||
| // 1. Activation : SwiGLU and ClampedSwiGLU are gated (input is [M, 2H]); | ||
| // SReLU is unary (input is [M, H], no gate split). | ||
| // 2. Data dtype : dtype of the activation input/output tensors. | ||
| // 3. Scale dtype : dtype of act_scales / grad_act_scales. | ||
| // 4. Shape {rows, hidden}: rows = M (tokens), hidden = H (output width; gated input is 2H). | ||
| // 5. GLU interleave : 0 = contiguous [a | b]; 32 = interleaved a/b blocks. Only valid | ||
| // for gated activations with hidden % 32 == 0; SReLU skips != 0. | ||
| // 6. compute_grad_scales : whether the backward also reduces grad_act_scales. | ||
|
|
||
| // Interleave is swept over {0, 32}; invalid combinations -- SReLU with any nonzero interleave, or | ||
| // a gated activation whose hidden is not divisible by the interleave -- are skipped at runtime by | ||
| // the GTEST_SKIP guards in the test body. | ||
| INSTANTIATE_TEST_SUITE_P( | ||
| OperatorTest_ScaledActivation, ScaledActivationTest, | ||
| ::testing::Combine( | ||
| ::testing::Values(ScaledActivationCase::kSwiGLU, ScaledActivationCase::kClampedSwiGLU, | ||
| ScaledActivationCase::kSReLU), | ||
| ::testing::Values(DType::kFloat32, DType::kBFloat16), // data dtype | ||
| ::testing::Values(DType::kFloat32, DType::kBFloat16), // scale dtype | ||
| ::testing::Values(std::pair<size_t, size_t>{17, 64}, // odd rows, aligned hidden | ||
| std::pair<size_t, size_t>{32, 32}, // minimal aligned square | ||
| std::pair<size_t, size_t>{128, 128}, // square | ||
| std::pair<size_t, size_t>{256, 64}, // many rows, narrow hidden | ||
| std::pair<size_t, size_t>{1024, 2048}, // large FFN-ish width | ||
| std::pair<size_t, size_t>{1, 1}, // single element | ||
| std::pair<size_t, size_t>{1, 96}, // single row | ||
| std::pair<size_t, size_t>{96, 1}, // single hidden column | ||
| std::pair<size_t, size_t>{13, 100}), // non-power-of-two | ||
| ::testing::Values(0, 32), // contiguous + interleaved | ||
| ::testing::Values(false, true)), // grad_act_scales off / on | ||
| test_name_generator); | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
gated_unscaledcallgated_unscaledcomputesunscaledon line 170, butgated_gradsunconditionally writes*unscaledon line 171, overwriting it. The first call is dead code — everygated_gradscase sets*unscaledbefore returning, so the result ofgated_unscaledis never observed. This should simply be removed.