diff --git a/examples/batch_size_scaling/ANALYSIS.md b/examples/batch_size_scaling/ANALYSIS.md new file mode 100644 index 000000000..9dba7362f --- /dev/null +++ b/examples/batch_size_scaling/ANALYSIS.md @@ -0,0 +1,239 @@ +# Batch-size scaling for Adam — analysis + +Companion to `README.md` (which gives the run matrix). This note records *why* the experiment +is designed the way it is: what scaling of Adam's hyperparameters with batch size is correct, why +the naive intuitions fail, and what we can claim with a guarantee versus only empirically. + +## Motivation + +The folklore is that large batches are necessary for good training. The gradient-noise-scale theory +(McCandlish et al., [arXiv:1812.06162](https://arxiv.org/abs/1812.06162)) already says otherwise: +there is a *critical batch size* below which adding batch buys almost nothing. Large batch is also +actively painful in practice — in RL it forces stale rollouts and heavy gradient accumulation. + +So: can small-batch training match large-batch if Adam's hyperparameters are scaled correctly — and +is there a scaling rule we can trust *before* tuning, in regimes we can't cheaply sweep (late +training, RL)? We want an **equivalence** rule with a generalization guarantee, not just a setting +that happens to be optimal in one regime. + +## Setup + +Adam's per-coordinate update is **normalized**: + +``` +θ ← θ − lr · m̂ / (√v̂ + ε) +m = β1·m + (1−β1)·g (first moment → smoothed gradient, the "signal") +v = β2·v + (1−β2)·g² (second moment → recent mean-square, per-coordinate scale) +``` + +(`m̂`, `v̂` are bias-corrected.) Write the minibatch gradient as `g = ⟨g⟩ + n`, signal plus +zero-mean noise, with batch-noise variance `⟨n²⟩ = σ²/b` for batch size `b`. + +The key fact is that `m̂/√v̂ ≈ O(1)` per coordinate (the gradient magnitude is divided out), so the +**step size is set by `lr` itself**, not by `lr × gradient`. + +## Why the SGD linear-scaling rule does **not** transfer to Adam + +For SGD the update is `−lr·g` — *linear* in the gradient. Scaling `lr ∝ b` and holding the momentum +timescale fixed makes B/b small steps sum to one large step (the noise averages out across the sum). +This is the classic linear scaling rule. + +Adam breaks this because of the `√v̂` normalization. Two ways to see it: + +- **Per-step magnitude (noise-dominated).** `m̂ → ⟨g⟩` (noise averages out of the first moment), + `√v̂ → √(σ²/b)`, so the update `≈ lr·⟨g⟩/√(σ²/b) = lr·⟨g⟩·√b/σ ∝ lr·√b`. The batch enters as + **√b**, not b — so a *linear* lr correction over-corrects. + +- **"32 small steps ≠ one large step."** Each small step is `lr·m̂/√v̂`, and that factor is itself + smaller for small batches: `m̂/√v̂ = ⟨g⟩/√(⟨g⟩²+σ²/b)`, whose denominator carries a noise floor + `σ²/b` that is **larger** for small batches and is *not* removed by averaging longer (averaging + smooths the estimate of `v̂`, not its mean). So `32` steps of `lr/32` sum to *less* than one large + step — by a factor `√[(⟨g⟩²+σ²/B)/(⟨g⟩²+σ²/b)]`, which in the noise-dominated limit is `≈ 1/√(B/b)`. + (In the *low-noise* limit the factor → 1 and the SGD intuition would hold; real small-batch training + is in the noisy regime where it doesn't.) + +This is exactly why a linearly-lr-scaled small-batch arm undertrains relative to large batch, while +keeping lr fixed at small batch *beats* large batch (more full-size — if noise-damped — steps per token). + +## The square-root (SDE) scaling rule + +The principled rule comes from the SDE view. Malladi et al., +[*On the SDEs and Scaling Rules for Adaptive Gradient Algorithms*](https://arxiv.org/abs/2205.10287) +(NeurIPS 2022), derive an SDE for Adam/RMSprop and prove the discrete updates are a **1st-order weak +approximation** of it. The scaling rule that keeps all batch sizes as 1st-order weak approximations of +the **same SDE** — i.e. makes the trajectories provably coincide to that order — is, for batch scaled +by a factor δ: + +``` +lr → lr · √δ +β1 → 1 − δ·(1 − β1) +β2 → 1 − δ·(1 − β2) +``` + +So learning rate scales as **√batch** (not linearly), and the EMA timescales are held fixed in +samples. This is the conservative "scale everything" instinct, corrected: the momentum scaling is +right, but the lr correction is √, not linear. + +## Equivalence vs optimality — and why equivalence is what generalizes + +There are two different questions, with different answers: + +- **What lr is *optimal* at each batch size?** Empirically ~batch-*insensitive* for Adam — a very weak + positive dependence (Marek et al., [arXiv:2507.07101](https://arxiv.org/abs/2507.07101), report the + optimal lr moving only ~3× over a 1024× batch change). So "keep lr, scale only β2" is roughly optimal, + and small batch then *beats* large batch on tokens-per-step efficiency. + +- **What scaling makes small batch *reproduce* large batch?** The √ rule above (SDE equivalence). + +The optimality answer is empirical and regime-specific — there's no guarantee it survives to late +training or transfers to RL. The **equivalence** answer carries a guarantee that is a property of the +*optimizer's response to gradient noise*, independent of the loss landscape: if small batch provably +reproduces the large-batch trajectory, then anything validated at large batch transfers. That is the +property we want "before optimizing." + +**Caveats on the guarantee (honest scope):** +- It is a *1st-order weak* / small-lr SDE result; it degrades when lr is large enough that + discretization error dominates. +- It guarantees equivalence *to large batch* — it inherits large batch's behavior, including any + suboptimality. (Reproducing large batch is the point: safe and transferable, not maximal.) +- It assumes the gradient-noise / diffusion regime (below the critical batch size). +- For **RL**: the guarantee covers the optimizer dynamics (the task-agnostic part). It does **not** + cover RL's non-stationarity (policy shift, off-policy staleness) — that is outside any batch-scaling + SDE. But it is still the most defensible transfer basis, and it is tuning-free. + +## Two effects that distinguish small from large batch even under the rule + +These are the reasons we expect the √-rule small-batch run to be *close to but not identical* to large +batch, and they are hypotheses the experiment tests (not established results): + +- **Sub-step freshness / staleness (favours small batch, slightly).** The SDE equivalence is only + 1st-order; the leading finite-step correction favours more, smaller steps, because each sub-step uses + a gradient at a *fresher* iterate (less discretization "self-staleness"). A large step computes one + gradient at `θ_t`; the equivalent small-batch sequence computes gradients at `θ_t, θ_{t+1}, …`. In RL + this is sharper: staleness is a fixed ~1 generation regardless of step size, so taking many small + sub-steps amortizes that fixed cost — small batch is structurally better-positioned on staleness. + +- **Constant-lr noise floor (the "head start may not hold" worry).** With constant lr, training + plateaus at a noise floor set by the SDE's stationary temperature, which grows with the *effective* + lr. Keeping full lr at small batch (the "optimal"/paper setting) raises the temperature → a **higher + floor**; the √-rule reproduces large batch's temperature → its floor. So full-lr small batch is + expected to descend fastest early but settle higher, while the √-rule run tracks large batch and can + overtake it late. (Counter-consideration: at leading order the displacement per token `∝ lr/√b` + persists late, so whether the extra motion is progress or just noise around a higher floor is the + empirical question — which is why the runs must go long.) + +- **Higher-order correction.** Carrying the next term, `m̂/√v̂ = ⟨g⟩/√(⟨g⟩²+σ²/b) ≈ + (⟨g⟩/√(σ²/b))·(1 − ⟨g⟩²/2·b/σ²·…)`, gives an SNR-dependent correction (not a power law) that is + largest when signal ≈ noise (early/mid training) and vanishes as noise dominates (late). So the + √-rule is the asymptotically exact part; deviations live in the transient. + +## Experiment + +All arms branch from one lightly-pretrained checkpoint and share one fresh, disjoint token stream +(paired comparison); loss is compared vs **tokens seen** (`training/consumed_tokens`), since optimizer +steps are not comparable across batch sizes. Small batch b=16, large B=512 (ratio 32). + +| Arm | batch | lr | β1 | β2 | what it isolates | +|---|---|---|---|---|---| +| A | 512 | 1e-4 | 0.9 | 0.95 | large-batch reference | +| B | 16 | 1e-4 | 0.9 | 0.95^(1/32) | keep-lr, β2 scaled (≈ empirical optimum) | +| D | 16 | 1e-4 | 0.9 | 0.95 | nothing scaled (naive) | +| G | 16 | 1e-4/√32 | 0.9 | 0.95 | √-lr only, **no** β scaling | +| I | 16 | 1e-4/√32 | 0.9 | 0.95^(1/32) | √-lr + β2 scaled | +| H | 16 | 1e-4/√32 | 0.9^(1/32) | 0.95^(1/32) | **√-lr + β1&β2 = full SDE rule** | + +(Plus precision twins of A at b=512: TF32 and fp16.) G→I→H is a momentum-scaling ladder at the correct +(√) lr; H is the theoretically-guaranteed rule. + +**Predictions.** H ≈ A across training (the equivalence guarantee), perhaps a hair better +(sub-step freshness). B fastest early but may settle to a higher floor late. D and the linear-lr +scalings undertrain. Precision twins overlap A (precision is not the axis that matters here). + +**The two curves that decide it:** (1) does H track A *into late training*? (2) does B's early lead +over A/H *erode* late? The first is the equivalence guarantee; the second is whether the empirically- +optimal "keep lr" is a transient head start or a durable win. + +## Results (preliminary — runs ongoing) + +The experiment ran in two rounds, and the decisive lesson is about *regime*: **the √-rule holds in +the noise-dominated regime it is derived for, and that regime is only reached well into training.** + +### Round 1 — signal-dominated, outside the rule's domain + +Branching the arms early (loss ~5 → ~3 over the run) put them in a **signal-dominated** regime: +per-step gradients are dominated by signal, not noise. Three things followed, all observed: + +- The √-rule is not even *expected* to apply — it is a diffusion / noise-dominated approximation. +- The noise-control knobs (batch beyond the critical batch, the β EMA windows, precision) had little + leverage and washed out. +- The dominant effect was just the learning rate: the large-batch reference at 1e-4 was visibly + *under*-tuned (a b=512 lr sweep 1e-4 → 3e-4 → 5.66e-4 closed most of the gap to the small-batch arms). + +So round 1's batch-size differences were an update-count / drift effect, not the noise-averaging the +√-rule addresses — it was not a valid test of the equivalence. This motivated round 2. + +### Round 2 — noise-dominated, the rule's domain + +Re-anchored deep in training: arms branched (weights only, cold optimizer) from a checkpoint at +~18B tokens / loss ~2.8, batch ratio 16 (b=32 ↔ b=512 = 65,536 ↔ 1,048,576 tokens), keeping the 1e-4 +baseline. Here the model is **noise-dominated** — measured per-step loss std is **0.027 at b=32 vs +0.008 at b=512, ratio ≈ 3.4** (close to the √16 = 4 that the 1/√batch noise law predicts), and this +per-step noise (~0.01–0.03) dwarfs the systematic inter-arm differences (~0.005) by an order of +magnitude. The std depends only on batch size, not on the optimizer settings. + +In this regime the √/SDE rule reproduces the large-batch trajectory. Comparing at matched tokens +(loss-vs-tokens, binned over ~1B-token windows), the two √-rule pairs overlay: + +| operating point | b=512 arm | b=32 (√-scaled) arm | gap | +|---|---|---|---| +| baseline (lr 1e-4) | A | H (lr 2.5e-5, β1 & β2 scaled) | ~0.0002 | +| aggressive (b=32 keep-lr) | J (lr 4e-4, √-up of B) | B (lr 1e-4, β2 scaled) | ~0.0006 | + +Both gaps are ~10× smaller than the ~0.005 spread *between* the conservative (A/H) and aggressive +(B/J) operating points — i.e. the √-rule maps each batch size onto the other's trajectory at both +operating points. + +**Secondary signals** — small (~0.0005–0.0008 nat) but directionally consistent at every bin, so +credible though not yet conclusive: + +- **β1 scaling helps.** The full-SDE arms (both β scaled) sit slightly below the β2-only arms at both + lr settings; dropping β1 makes the √-rule arm slightly *under*-track the baseline. A (mild) point for + the full SDE rule over the paper "β2 only" rule. +- **fp16 edges bf16** by ~0.0005 (plausibly fp16's 10 mantissa bits vs bf16's 7) — refining "precision + doesn't matter" to "precision barely matters, in the direction of more mantissa bits." +- **Batch size matters far less in round 2 than round 1** once scaled — consistent with the equivalence + holding only in the noise-dominated regime. + +**Regime is the through-line.** Round 1 (signal-dominated) is outside the rule's domain and the noise +knobs are inert; round 2 (noise-dominated) is inside it, the rule holds, and the noise knobs acquire +leverage. The knobs lighting up only in round 2 is itself a fingerprint of the regime. This is the +point for the motivation: **RL and late-stage training are noise-dominated — the same regime as round +2** — so "equivalence holds in round 2" is the evidence it transfers to the settings we care about. + +### Methodology note: don't reuse the data seed when branching + +Round 2's arms showed a sharp, persistent loss step at *exactly* 17.69B tokens. It is **not** a +data-quality artifact: the document shuffle is seeded independently of batch size / num_samples, so a +child run with the same seed replays the parent's document order. The arms branched from a checkpoint +trained to 17.69B tokens, so for the first 17.69B tokens they re-traverse data the parent already fit +(low loss), then jump to the true level on reaching novel data. **Lesson:** when branching from a +checkpoint, use a different data seed (or skip the parent's consumed prefix). All comparisons above are +windowed to the post-17.69B novel-data region. + +### Status & caveats + +- Runs are **ongoing and not converged**; numbers are preliminary and expected to firm up. +- Comparisons use **training loss** — but on a **single epoch** this *is* effectively a held-out + measure: each batch's loss is computed before the model trains on it, and (post-revisit) on + never-before-seen data, so there is no real train/validation distinction. The broken eval-loss + logging ([#538](https://github.com/ServiceNow/Fast-LLM/issues/538)) is therefore not a real + limitation here. The one exception is the 17.69B re-read prefix (data the branch parent already + saw) — memorization-contaminated, and excluded from all comparisons. +- Single 0.5B model, one dataset; the secondary signals are ~1.5σ per bin individually, credible only + via cross-bin directional consistency. + +## References + +- McCandlish, Kaplan, Amodei et al., *An Empirical Model of Large-Batch Training*, [arXiv:1812.06162](https://arxiv.org/abs/1812.06162) — critical batch size / gradient noise scale. +- Malladi, Lyu, Panigrahi, Arora, *On the SDEs and Scaling Rules for Adaptive Gradient Algorithms*, [arXiv:2205.10287](https://arxiv.org/abs/2205.10287) (NeurIPS 2022) — square-root scaling rule with an SDE-equivalence guarantee. +- Marek et al., *(batch-size hyperparameter scaling)*, [arXiv:2507.07101](https://arxiv.org/abs/2507.07101) — empirical: optimal Adam lr is ~batch-insensitive; β2 scales to hold the token half-life. diff --git a/examples/batch_size_scaling/README.md b/examples/batch_size_scaling/README.md new file mode 100644 index 000000000..fca49816a --- /dev/null +++ b/examples/batch_size_scaling/README.md @@ -0,0 +1,60 @@ +# Batch-size scaling experiment (Qwen2.5-0.5B) + +Does small-batch Adam training **reproduce** large-batch training when the hyperparameters are scaled +by the square-root (SDE) rule, and how does that compare to the "keep lr, scale β2" paper rule? See +[`ANALYSIS.md`](ANALYSIS.md) for the theory, the predictions, and the results. + +Loss is compared **vs tokens seen** (`training/consumed_tokens`) — optimizer steps are not comparable +across batch sizes. All arms share one model init (the warmup checkpoint) and the same shuffled token +stream (the `experiment` split), so the curves form a paired comparison. + +## Reproduce + +```bash +# 0. Tokenize FineWeb-Edu into disjoint warmup / experiment / validation splits. +fast-llm prepare gpt_memmap --config prepare.yaml + +# 1. Throwaway warmup from scratch (~1B tokens); note the final checkpoint iteration. +torchrun --nproc_per_node=4 -m fast_llm.cli train gpt --config warmup.yaml + +# 2. Branch each arm from the warmup checkpoint (weights only, cold optimizer) via `pretrained`, +# one run per arm. arm_base.yaml holds the shared settings; arms differ only in +# schedule.depth_first_micro_batches (batch), optimizer.learning_rate.base, and the betas. +fast-llm train gpt --config arm_base.yaml \ + run.experiment_dir=experiments/batch_size/ \ + schedule.depth_first_micro_batches= \ + optimizer.learning_rate.base= optimizer.beta_1= optimizer.beta_2= +``` + +Set `pretrained.path` to the warmup checkpoint and the W&B `entity_name` / `project_name` in the +configs to your own. Batch (tokens) = `depth_first_micro_batches × micro_batch_size × data_parallel`; +`b` below is in sequences of 2048. + +## Arms + +Two √-rule pairs anchor the comparison — scale **down** from the large-batch baseline (A → H) and **up** +from the small-batch keep-lr arm (B → J). With batch ratio `r`: + +| arm | batch | lr | β1 | β2 | what it isolates | +|---|---|---|---|---|---| +| A | large | 1e-4 | 0.9 | 0.95 | large-batch baseline | +| H | small | 1e-4/√r | 0.9^(1/r) | 0.95^(1/r) | **full √/SDE rule** (= A scaled down) | +| I | small | 1e-4/√r | 0.9 | 0.95^(1/r) | √-lr + β2 only (drops β1) | +| B | small | 1e-4 | 0.9 | 0.95^(1/r) | keep-lr + β2 (paper rule) | +| M | small | 1e-4 | 0.9^(1/r) | 0.95^(1/r) | keep-lr + both β | +| J | large | 1e-4·√r | 0.9 | 0.95 | large-batch √-up image of B | +| K | large | (lr sweep) | 0.9 | 0.95 | large-batch lr sweep | +| L | large | 1e-4 | 0.9 | 0.95 | fp16 twin of A (precision) | + +All arms: constant lr (no decay), `weight_decay=0`, `epsilon=1e-8`, gradient-norm clip 5.0, shared +seed. `num_samples` (= `depth_first_micro_batches × train_iters`) is matched across arms so the shuffled +stream is identical — only the batching differs. + +Two rounds were run (details and results in [`ANALYSIS.md`](ANALYSIS.md)): round 1 at `r=32` early in +training (**signal-dominated** — the wrong regime, the rule's knobs wash out), and round 2 at `r=16` +deep in training (**noise-dominated** — where the rule holds and the √-pairs overlay). + +> **Gotcha — don't reuse the data seed when branching.** The document shuffle is seeded independently +> of batch size / `num_samples`, so a child run with the same seed *replays the parent's already-seen +> data* until it reaches the parent's token horizon (a sharp, persistent loss step). Use a different +> seed, or skip the parent's consumed prefix. See the methodology note in `ANALYSIS.md`. diff --git a/examples/batch_size_scaling/arm_base.yaml b/examples/batch_size_scaling/arm_base.yaml new file mode 100644 index 000000000..f3e2824bf --- /dev/null +++ b/examples/batch_size_scaling/arm_base.yaml @@ -0,0 +1,62 @@ +# Stage 2 — branch arms from the warmup checkpoint. Each arm is a SEPARATE run, launched with +# per-arm overrides (see README.md for the arm matrix). This base file holds the shared settings +# and defaults to a naive unscaled small-batch configuration; every arm overrides from here. +# +# Arms inherit the architecture from the warmup checkpoint (pretrained.load_config: model), +# start fresh (completed_steps = 0) so all arms consume the SAME `experiment` token stream from +# offset 0, and differ only in: schedule.depth_first_micro_batches, optimizer betas/lr, +# training.train_iters (chosen so num_samples is identical across arms), eval interval, and dir. +# +# fast-llm train gpt --config arm_base.yaml +pretrained: + path: experiments/batch_size/warmup/checkpoint/LATEST # set to the ~1h checkpoint iteration + format: distributed + model_weights: true + optimizer_state: false # cold optimizer; each arm warms its own v under its own beta_2 + load_config: model +training: + train_iters: 1280000 # small-batch cap; num_samples = 2 x 1280000 = 2,560,000 (~42B tokens at micro_batch_size 16384, within the ~94B experiment split). Monitor and stop. + num_workers: 8 + wandb: + entity_name: null # set to your W&B entity to enable logging + project_name: batch_size_scaling + group_name: arms + logs: + interval: 10 + checkpoint: + interval: 30000 # ~1B tokens for small arms + keep: 2 + evaluators: + validation: + type: loss + interval: 6000 # ~200M-token cadence for small arms + iterations: 50 +schedule: + depth_first_micro_batches: 2 # small batch = 2 x 16384 = 32768 tokens (large arms override to 64 -> 1,048,576) +data: + datasets: + training: + type: file + path: data/batch_size/fineweb_edu_qwen/fast_llm_config_experiment.yaml + validation: + type: file + path: data/batch_size/fineweb_edu_qwen/fast_llm_config_validation.yaml + micro_batch_size: 16384 # tokens packed per micro-batch (throughput knob; documents are masked per-doc within the pack). Same value for all arms -> identical packing. +optimizer: + learning_rate: + base: 1.0e-04 + decay_style: constant + warmup_iterations: 0 + weight_decay: 0.0 + beta_1: 0.9 + beta_2: 0.95 + epsilon: 1.0e-08 + gradient_norm_clipping: 5.0 # raised from 1.0: grad norms ~1, so clipping ~never fires -> avoids batch-size-dependent differential clipping; still a safety net vs spikes +model: + multi_stage: + zero_stage: 1 # ZeRO-1 (optimizer-state sharding); weights replicated across data-parallel ranks + distributed: + compute_dtype: bf16 + seed: 7 # shared across arms -> identical data order +run: + experiment_dir: experiments/batch_size/arm_D diff --git a/examples/batch_size_scaling/prepare.yaml b/examples/batch_size_scaling/prepare.yaml new file mode 100644 index 000000000..cd8dc3d39 --- /dev/null +++ b/examples/batch_size_scaling/prepare.yaml @@ -0,0 +1,22 @@ +# Stage 0 — tokenize FineWeb-Edu with the Qwen2.5 tokenizer into a Fast-LLM memmap dataset, +# partitioned into three disjoint splits: +# warmup — the throwaway from-scratch prefix (stage 1) +# experiment — fresh, unseen-by-warmup data, shared by every arm (stage 2) +# validation — held out, shared by all runs for comparable eval-loss curves +# +# fast-llm prepare gpt_memmap --config examples/batch_size_scaling/prepare.yaml +output_path: data/batch_size/fineweb_edu_qwen +dataset: + path: HuggingFaceFW/fineweb-edu + config_name: sample-100BT + split: train + trust_remote_code: true +tokenizer: + path: Qwen/Qwen2.5-0.5B +# Qwen has no native BOS; skip it and separate documents by EOS (<|endoftext|>) only (PR #534). +add_bos: false +add_eos: true +splits: + warmup: 0.05 + experiment: 0.94 + validation: 0.01 diff --git a/examples/batch_size_scaling/results/round1_all_arms.png b/examples/batch_size_scaling/results/round1_all_arms.png new file mode 100644 index 000000000..280635bb9 Binary files /dev/null and b/examples/batch_size_scaling/results/round1_all_arms.png differ diff --git a/examples/batch_size_scaling/results/round2_all_arms.png b/examples/batch_size_scaling/results/round2_all_arms.png new file mode 100644 index 000000000..c982c3d0f Binary files /dev/null and b/examples/batch_size_scaling/results/round2_all_arms.png differ diff --git a/examples/batch_size_scaling/results/round2_equivalence.png b/examples/batch_size_scaling/results/round2_equivalence.png new file mode 100644 index 000000000..1145d844f Binary files /dev/null and b/examples/batch_size_scaling/results/round2_equivalence.png differ diff --git a/examples/batch_size_scaling/warmup.yaml b/examples/batch_size_scaling/warmup.yaml new file mode 100644 index 000000000..1affa4f20 --- /dev/null +++ b/examples/batch_size_scaling/warmup.yaml @@ -0,0 +1,81 @@ +# Stage 1 — from-scratch warmup of a Qwen2.5-0.5B model to a mid-early-training branch point. +# This prefix is throwaway: only the final checkpoint is reused, as `pretrained` init for the arms. +# Run data-parallel on 4 GPUs; stop at ~1h wall-clock and branch from the last checkpoint. +# torchrun --nproc_per_node=4 -m fast_llm.cli train gpt --config examples/batch_size_scaling/warmup.yaml +# +# Global batch = micro_batch_size (2048) x depth_first_micro_batches (128) x data_parallel (4) +# = 512 sequences of length 2048 (matches Qwen2.5's pretraining batch). +training: + train_iters: 1000 # cap only — kill at ~1h and use the latest checkpoint + num_workers: 8 + wandb: + entity_name: null # set to your W&B entity to enable logging + project_name: batch_size_scaling + group_name: warmup + logs: + interval: 10 + checkpoint: + interval: 100 # frequent saves so we can branch near the 1h mark + keep: 3 +schedule: + depth_first_micro_batches: 128 +data: + datasets: + training: + type: file + path: data/batch_size/fineweb_edu_qwen/fast_llm_config_warmup.yaml + micro_batch_size: 2048 +optimizer: + learning_rate: + base: 1.0e-04 + decay_style: constant + warmup_iterations: 100 + weight_decay: 0.0 + beta_1: 0.9 + beta_2: 0.95 + epsilon: 1.0e-08 + gradient_norm_clipping: 1.0 +model: + base_model: + embeddings: + vocab_size: 151936 + dropout: 0.0 + decoder: + block: + mixer: + type: attention + rotary: + type: default + theta: 1000000 + heads: 14 + head_groups: 2 + head_size: 64 + add_linear_biases: false + query_layer: {bias: {enabled: true}} + key_layer: {bias: {enabled: true}} + value_layer: {bias: {enabled: true}} + dense_layer: {bias: {enabled: false}} + dropout: 0.0 + mlp: + intermediate_size: 4864 + add_linear_biases: false + gated: true + activation: silu + normalization: + type: rms_norm + epsilon: 1.0e-06 + dropout: 0.0 + num_blocks: 24 + head: + normalization: + type: rms_norm + epsilon: 1.0e-06 + hidden_size: 896 + tied_embedding_weight: true + multi_stage: + zero_stage: 2 + distributed: + compute_dtype: bf16 + seed: 1234 +run: + experiment_dir: experiments/batch_size/warmup