Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion fast_llm/engine/config_utils/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,13 @@ class ParameterConfig(ModuleConfig):
" Combines multiplicatively with the scale set by the parent layer, if applicable.",
hint=FieldHint.feature,
)
weight_decay: float | bool | None = Field(
default=None,
desc="Override the default weight decay for this parameter."
" Set to `True` or `False` to enable or disable the optimizer weight decay,"
" or to a number to use a specific value.",
hint=FieldHint.feature,
)
# TODO: Initialization, lr_scale

def get_parameter(
Expand All @@ -67,7 +74,7 @@ def get_parameter(
dims,
init_method=default_initialization if self.initialization.is_default else self.initialization,
lr_scale=combine_lr_scales(lr_scale, self.lr_scale),
weight_decay=weight_decay,
weight_decay=weight_decay if self.weight_decay is None else self.weight_decay,
allow_sequence_tensor_parallel=allow_sequence_tensor_parallel,
)
if peft is not None:
Expand Down
36 changes: 22 additions & 14 deletions fast_llm/engine/multi_stage/stage_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,14 @@ def get_param_groups(
end = fsdp.index_buffer_to_shard(buffer_begin + (lr_scale_index + 1) * chunk_size)
if lr_scale == 0 or begin == end:
continue
optimizer_params = (parameter_meta.param_weight_decay, lr_scale)
# Resolve to the optimizer weight decay so the group key is unambiguous:
# `True` -> `None` (optimizer default), `False` -> `0.0`, a float -> itself.
# Keying on the raw value would merge `True` with a literal `1.0`.
weight_decay = parameter_meta.param_weight_decay
group_weight_decay = (
(None if weight_decay else 0.0) if isinstance(weight_decay, bool) else weight_decay
)
optimizer_params = (group_weight_decay, lr_scale)
if optimizer_params in grouped_parameter_slices:
last_slice = grouped_parameter_slices[optimizer_params][-1]
if begin == last_slice.stop:
Expand All @@ -257,17 +264,17 @@ def get_param_groups(

param_groups += [
param_group_cls(
name=f"wd_{weight_decay}_lr_scale_{lr_scale}", # noqa
name=f"wd_{group_weight_decay}_lr_scale_{lr_scale}", # noqa
params=[fsdp.weight_shard[slice_] for slice_ in slices], # noqa
grads=[fsdp.grad_shard[slice_] for slice_ in slices], # noqa
**{ # noqa
name: [optimizer_state[i][slice_] for slice_ in slices]
for name, optimizer_state in optimizer_state_shards.items()
},
weight_decay=None if weight_decay else 0.0, # noqa
weight_decay=group_weight_decay, # noqa
lr_scale=lr_scale, # noqa
)
for (weight_decay, lr_scale), slices in grouped_parameter_slices.items()
for (group_weight_decay, lr_scale), slices in grouped_parameter_slices.items()
]

# Get the weight slices to use for grad norm computation, merging consecutive slices.
Expand Down Expand Up @@ -343,14 +350,15 @@ def _get_parameter_metas(self) -> tuple[list[ParameterMeta], list[ParameterMeta]

@classmethod
def _reorder_parameter_metas(cls, parameter_metas):
reorder_index = sorted(
range(len(parameter_metas)),
key=lambda i: (
parameter_metas[i].param_weight_decay,
parameter_metas[i].param_weight_decay == parameter_metas[i].is_tensor_parallel,
parameter_metas[i].param_weight_decay != parameter_metas[i].sequence_tensor_parallel,
),
)
reordered_metas = [parameter_metas[i] for i in reorder_index]
def _sort_key(i):
# `param_weight_decay` may be a float; reduce it to a bool so the ordering (and the
# sequence-parallel contiguity it guarantees) follows the decayed/not-decayed split.
weight_decay = bool(parameter_metas[i].param_weight_decay)
return (
weight_decay,
weight_decay == parameter_metas[i].is_tensor_parallel,
weight_decay != parameter_metas[i].sequence_tensor_parallel,
)

return reordered_metas
reorder_index = sorted(range(len(parameter_metas)), key=_sort_key)
return [parameter_metas[i] for i in reorder_index]
2 changes: 2 additions & 0 deletions fast_llm/layers/common/linear/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def get_layer(
bias = self.bias.get_parameter(
(out_dim,),
default_initialization=default_bias_initialization,
weight_decay=False,
lr_scale=lr_scale,
default_enabled=default_add_bias,
peft=None,
Expand Down Expand Up @@ -222,6 +223,7 @@ def get_layer(
bias = self.bias.get_parameter(
(in_dim,),
default_initialization=default_bias_initialization,
weight_decay=False,
lr_scale=lr_scale,
default_enabled=default_add_bias,
peft=peft,
Expand Down
3 changes: 3 additions & 0 deletions fast_llm/layers/common/normalization/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,12 +216,14 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float |
self.weight = self._config.weight.get_parameter(
(hidden_dim,),
default_initialization=init_zeros_ if self._config.zero_centered else init_ones_,
weight_decay=False,
lr_scale=self._lr_scale,
peft=None,
)
self.bias = self._config.bias.get_parameter(
(hidden_dim,),
default_initialization=init_zeros_,
weight_decay=False,
lr_scale=self._lr_scale,
peft=None,
)
Expand Down Expand Up @@ -282,6 +284,7 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float |
self.weight = self._config.weight.get_parameter(
(hidden_dim,),
default_initialization=init_zeros_ if self._config.zero_centered else init_ones_,
weight_decay=False,
lr_scale=self._lr_scale,
peft=None,
)
Expand Down
1 change: 1 addition & 0 deletions fast_llm/layers/decoder/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def __init__(
self.output_scale = self._config.output_scale.get_parameter(
(scalar_dim,),
default_initialization=init_ones_,
weight_decay=False,
lr_scale=self._lr_scale,
peft=self._peft,
)
Expand Down
2 changes: 2 additions & 0 deletions fast_llm/layers/decoder/mlp/mixture_of_experts.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,14 @@ def __init__(
self.router_scale = self._config.router_scale.get_parameter(
(self._hidden_dim,),
default_initialization=init_ones_,
weight_decay=False,
lr_scale=self._lr_scale,
peft=self._peft,
)
self.router_per_expert_scale = self._config.router_per_expert_scale.get_parameter(
(TensorDim("experts", self._config.experts),),
default_initialization=init_ones_,
weight_decay=False,
lr_scale=self._lr_scale,
peft=self._peft,
)
Expand Down
2 changes: 2 additions & 0 deletions fast_llm/layers/ssm/gdn.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ def __init__(
self.dt_bias: ParameterMeta = self._config.dt_bias_weight.get_parameter(
(self._value_heads_dim,),
default_initialization=init_ones_,
weight_decay=False,
lr_scale=self._lr_scale,
peft=self._peft,
)
Expand All @@ -310,6 +311,7 @@ def __init__(
default_initialization=LambdaInitializer(
lambda _, tensor, generator: tensor.uniform_(0, 16, generator=generator).log_()
),
weight_decay=False,
lr_scale=self._lr_scale,
peft=self._peft,
)
Expand Down
2 changes: 2 additions & 0 deletions fast_llm/layers/ssm/kda.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ def __init__(
self.dt_bias: ParameterMeta = self._config.dt_bias_weight.get_parameter(
(self._projection_dim,),
default_initialization=init_ones_,
weight_decay=False,
lr_scale=self._lr_scale,
peft=self._peft,
)
Expand All @@ -351,6 +352,7 @@ def __init__(
default_initialization=LambdaInitializer(
lambda _, tensor, generator: tensor.uniform_(1, 16, generator=generator).log_()
),
weight_decay=False,
lr_scale=self._lr_scale,
peft=self._peft,
)
Expand Down
4 changes: 2 additions & 2 deletions fast_llm/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def __init__(
tensor_name: str = "",
dims: tuple[TensorDim, ...],
init_method: "Initialization | typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None] | None" = None,
weight_decay: bool = True,
weight_decay: float | bool = True,
# Pass a list to split the parameter in contiguous (dim=0) chunks of equal size for optimization.
lr_scale: float | None | tuple[float | None, ...] = None,
requires_grad: bool = True,
Expand Down Expand Up @@ -285,7 +285,7 @@ def __new__(
tensor_name: str = "",
dims: tuple[TensorDim, ...],
init_method: "Initializer | typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None] | None",
weight_decay: bool = True,
weight_decay: float | bool = True,
lr_scale: float | None | tuple[float | None, ...] = None,
allow_sequence_tensor_parallel: bool = True,
allow_no_grad: bool = False,
Expand Down
Loading