Skip to content

Rewrite Triton normalization backward kernel_1 (#499)#546

Open
jlamypoirier wants to merge 1 commit into
mainfrom
jlp_norm_kernel1_rewrite
Open

Rewrite Triton normalization backward kernel_1 (#499)#546
jlamypoirier wants to merge 1 commit into
mainfrom
jlp_norm_kernel1_rewrite

Conversation

@jlamypoirier

Copy link
Copy Markdown
Collaborator

Summary

Closes the backward-pass gap in layer_norm/rms_norm (issue #499). On H100 the Triton backward trailed apex and torch-compiled by ~1.1–1.6× at most hidden sizes — worst on tall-narrow shapes. kernel_1 was both the bottleneck and the source of an oversized partial-reduction buffer for kernel_2.

Two changes to kernel_1 (fast_llm/functional/triton/normalization.py):

  1. Decouple the register tile from n_cols. A block_size_row × block_size_col tile grid-strides the columns, so occupancy no longer collapses as the hidden size grows. Rows wider than one chunk use a two-pass scheme (reduce the per-row corrections, then re-read to write grad_input and the partials); narrower rows stay single-pass with no re-read.
  2. Bound the partial-reduction work, the way apex does. apex (both the general fused path and the hand-tuned fast path) does not avoid a second reduction kernel — it bounds the number of partial rows to a small constant via row grid-striding. Previously kernel_1 emitted one partial row per block_size_row input rows, so the buffer kernel_2 reduces grew with the row count (4096 rows at 32768×1024 → kernel_2 ran at ~10% of bandwidth). Single-pass now grid-strides the rows with a program count fixed at multi_processor_count × 2, folding many row tiles into one fp32-accumulated partial. The buffer is then independent of the row count. Two waves per SM is the measured knee — one wave starves grad_input latency-hiding; more only re-inflates kernel_2.

Parameter-grad partials reduce in fp32 (the store casts to the buffer dtype) — reducing in bf16 measurably degrades the parameter gradients.

Results (H100, bf16)

Backward µs vs. the fastest competitor (apex_fast for LN, torch-compiled-max otherwise). Bold = match-or-beat.

layer_norm ours best alt ratio rms_norm ours best alt ratio
8192×1024 27.9 30.2 0.92× 4096×4096 46.5 62.5 0.74×
1024×8192 39.4 47.7 0.83× 8192×1024 24.2 30.7 0.79×
2048×4096 32.2 35.8 0.90× 2048×4096 26.4 31.6 0.84×
16384×1024 49.1 50.7 0.97× 32768×1024 78.3 86.9 0.90×
4096×4096 56.0 56.9 0.98× 16384×1024 43.1 47.4 0.91×
32768×1024 89.4 88.6 1.01× 16384×2048 81.7 83.8 0.98×
8192×4096 99.3 83.4 1.19× 2048×16384 135 115 1.18×
512×16384 50.3 39.2 1.28× 512×16384 41.9 31.9 1.31×
  • The tall-narrow shapes that motivated the issue go from ~1.3–1.6× behind to parity-or-better, match-or-beating the fastest alternative on ~8/15 shapes per norm.
  • apex's general fused path is beaten across the board (it is 1.5–3× slower on backward).
  • kernel_2 is no longer the problem: 2–5 µs on single-pass shapes (was up to 47 µs).
  • Remaining gap: wide hidden sizes (n_cols ≥ 8192, two-pass) stay ~1.1–1.3× behind, bounded by the two-pass column re-read. This path is untouched here and is the natural follow-up.

Forward is at parity across implementations and is unchanged.

Benchmark harness

tools/benchmark/triton_kernels:

  • Isolated, cold-L2 backward timing (forward untimed, L2 flushed, then the backward timed) — training-representative. The prior fwd_bwd − fwd number had a warm-L2 confound: the forward left the saved output partly resident in L2, flattering the backward in a way real training never sees, which made the rewrite look like a regression on some mid shapes where it is actually at parity.
  • Per-kernel device-time breakdown, so kernel_1 and kernel_2 can be attributed separately.

Validation

tests/layers/ and tests/tools/test_triton_benchmark.py: 733 passed, 27 skipped (H100). Parameter-grad precision is bit-equivalent to the previous kernel (grad_weight rel-rms ≈ 2.8–2.9e-3).


Authored by Claude Opus 4.8 (Claude Code).

🤖 Generated with Claude Code

The backward of `layer_norm`/`rms_norm` trailed apex and torch-compiled by
1.1-1.6x at most hidden sizes, worst on tall-narrow shapes. kernel_1 was the
bottleneck and over-produced grad_weight/grad_bias partials.

kernel_1:
- Decouple the register tile from `n_cols`: a `block_size_row x block_size_col`
  tile grid-strides the columns, so occupancy no longer collapses as hidden size
  grows. Rows wider than one chunk use a two-pass scheme (reduce per-row
  corrections, then re-read to write grad_input and the partials); narrower rows
  stay single pass.
- Bound the partial-reduction work like apex: single pass grid-strides the rows
  with a program count fixed at `multi_processor_count x 2`, folding many row
  tiles into one fp32-accumulated partial. The partial buffer kernel_2 reduces is
  then independent of the row count instead of growing with it (e.g. 4096 -> ~260
  rows at 32768x1024), which was the dominant remaining cost. Two waves per SM is
  the measured knee: one starves grad_input latency-hiding, more only re-inflates
  kernel_2.

Parameter-grad partials reduce in fp32 (the store casts to the buffer dtype);
reducing in bf16 degraded the parameter gradients.

Result (H100, bf16): tall-narrow shapes go from ~1.3-1.6x behind to parity or
better against the fastest alternative (apex_fast / torch-compiled-max), and
apex's general fused path is beaten across the board. Wide hidden sizes
(two-pass) remain ~1.1-1.3x behind, bounded by the column re-read.

Benchmark harness (tools/benchmark/triton_kernels):
- Measure backward in isolation with a cold L2 (forward untimed, L2 flushed, then
  the backward timed), which is training-representative. The prior fwd_bwd-minus-fwd
  number had a warm-L2 confound: the forward left the saved output partly resident,
  flattering the backward in a way real training never sees.
- Add a per-kernel device-time breakdown so kernel_1 and kernel_2 can be attributed
  separately.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant