diff --git a/fast_llm/engine/config_utils/parameter.py b/fast_llm/engine/config_utils/parameter.py index 3e2b61120..083d5fa46 100644 --- a/fast_llm/engine/config_utils/parameter.py +++ b/fast_llm/engine/config_utils/parameter.py @@ -49,6 +49,13 @@ class ParameterConfig(ModuleConfig): " Combines multiplicatively with the scale set by the parent layer, if applicable.", hint=FieldHint.feature, ) + weight_decay: float | bool | None = Field( + default=None, + desc="Override the default weight decay for this parameter." + " Set to `True` or `False` to enable or disable the optimizer weight decay," + " or to a number to use a specific value.", + hint=FieldHint.feature, + ) # TODO: Initialization, lr_scale def get_parameter( @@ -67,7 +74,7 @@ def get_parameter( dims, init_method=default_initialization if self.initialization.is_default else self.initialization, lr_scale=combine_lr_scales(lr_scale, self.lr_scale), - weight_decay=weight_decay, + weight_decay=weight_decay if self.weight_decay is None else self.weight_decay, allow_sequence_tensor_parallel=allow_sequence_tensor_parallel, ) if peft is not None: diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index ea737524b..e696a8f83 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -245,7 +245,14 @@ def get_param_groups( end = fsdp.index_buffer_to_shard(buffer_begin + (lr_scale_index + 1) * chunk_size) if lr_scale == 0 or begin == end: continue - optimizer_params = (parameter_meta.param_weight_decay, lr_scale) + # Resolve to the optimizer weight decay so the group key is unambiguous: + # `True` -> `None` (optimizer default), `False` -> `0.0`, a float -> itself. + # Keying on the raw value would merge `True` with a literal `1.0`. + weight_decay = parameter_meta.param_weight_decay + group_weight_decay = ( + (None if weight_decay else 0.0) if isinstance(weight_decay, bool) else weight_decay + ) + optimizer_params = (group_weight_decay, lr_scale) if optimizer_params in grouped_parameter_slices: last_slice = grouped_parameter_slices[optimizer_params][-1] if begin == last_slice.stop: @@ -257,17 +264,17 @@ def get_param_groups( param_groups += [ param_group_cls( - name=f"wd_{weight_decay}_lr_scale_{lr_scale}", # noqa + name=f"wd_{group_weight_decay}_lr_scale_{lr_scale}", # noqa params=[fsdp.weight_shard[slice_] for slice_ in slices], # noqa grads=[fsdp.grad_shard[slice_] for slice_ in slices], # noqa **{ # noqa name: [optimizer_state[i][slice_] for slice_ in slices] for name, optimizer_state in optimizer_state_shards.items() }, - weight_decay=None if weight_decay else 0.0, # noqa + weight_decay=group_weight_decay, # noqa lr_scale=lr_scale, # noqa ) - for (weight_decay, lr_scale), slices in grouped_parameter_slices.items() + for (group_weight_decay, lr_scale), slices in grouped_parameter_slices.items() ] # Get the weight slices to use for grad norm computation, merging consecutive slices. @@ -343,14 +350,15 @@ def _get_parameter_metas(self) -> tuple[list[ParameterMeta], list[ParameterMeta] @classmethod def _reorder_parameter_metas(cls, parameter_metas): - reorder_index = sorted( - range(len(parameter_metas)), - key=lambda i: ( - parameter_metas[i].param_weight_decay, - parameter_metas[i].param_weight_decay == parameter_metas[i].is_tensor_parallel, - parameter_metas[i].param_weight_decay != parameter_metas[i].sequence_tensor_parallel, - ), - ) - reordered_metas = [parameter_metas[i] for i in reorder_index] + def _sort_key(i): + # `param_weight_decay` may be a float; reduce it to a bool so the ordering (and the + # sequence-parallel contiguity it guarantees) follows the decayed/not-decayed split. + weight_decay = bool(parameter_metas[i].param_weight_decay) + return ( + weight_decay, + weight_decay == parameter_metas[i].is_tensor_parallel, + weight_decay != parameter_metas[i].sequence_tensor_parallel, + ) - return reordered_metas + reorder_index = sorted(range(len(parameter_metas)), key=_sort_key) + return [parameter_metas[i] for i in reorder_index] diff --git a/fast_llm/layers/common/linear/config.py b/fast_llm/layers/common/linear/config.py index 803edc302..0917ac401 100644 --- a/fast_llm/layers/common/linear/config.py +++ b/fast_llm/layers/common/linear/config.py @@ -139,6 +139,7 @@ def get_layer( bias = self.bias.get_parameter( (out_dim,), default_initialization=default_bias_initialization, + weight_decay=False, lr_scale=lr_scale, default_enabled=default_add_bias, peft=None, @@ -222,6 +223,7 @@ def get_layer( bias = self.bias.get_parameter( (in_dim,), default_initialization=default_bias_initialization, + weight_decay=False, lr_scale=lr_scale, default_enabled=default_add_bias, peft=peft, diff --git a/fast_llm/layers/common/normalization/normalization.py b/fast_llm/layers/common/normalization/normalization.py index dda12f17b..e7f9d0bb9 100644 --- a/fast_llm/layers/common/normalization/normalization.py +++ b/fast_llm/layers/common/normalization/normalization.py @@ -216,12 +216,14 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | self.weight = self._config.weight.get_parameter( (hidden_dim,), default_initialization=init_zeros_ if self._config.zero_centered else init_ones_, + weight_decay=False, lr_scale=self._lr_scale, peft=None, ) self.bias = self._config.bias.get_parameter( (hidden_dim,), default_initialization=init_zeros_, + weight_decay=False, lr_scale=self._lr_scale, peft=None, ) @@ -282,6 +284,7 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | self.weight = self._config.weight.get_parameter( (hidden_dim,), default_initialization=init_zeros_ if self._config.zero_centered else init_ones_, + weight_decay=False, lr_scale=self._lr_scale, peft=None, ) diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index b88fc89ec..f8d08e683 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -127,6 +127,7 @@ def __init__( self.output_scale = self._config.output_scale.get_parameter( (scalar_dim,), default_initialization=init_ones_, + weight_decay=False, lr_scale=self._lr_scale, peft=self._peft, ) diff --git a/fast_llm/layers/decoder/mlp/mixture_of_experts.py b/fast_llm/layers/decoder/mlp/mixture_of_experts.py index d9851d022..b672054a7 100644 --- a/fast_llm/layers/decoder/mlp/mixture_of_experts.py +++ b/fast_llm/layers/decoder/mlp/mixture_of_experts.py @@ -88,12 +88,14 @@ def __init__( self.router_scale = self._config.router_scale.get_parameter( (self._hidden_dim,), default_initialization=init_ones_, + weight_decay=False, lr_scale=self._lr_scale, peft=self._peft, ) self.router_per_expert_scale = self._config.router_per_expert_scale.get_parameter( (TensorDim("experts", self._config.experts),), default_initialization=init_ones_, + weight_decay=False, lr_scale=self._lr_scale, peft=self._peft, ) diff --git a/fast_llm/layers/ssm/gdn.py b/fast_llm/layers/ssm/gdn.py index ef6c6154f..1388f8908 100644 --- a/fast_llm/layers/ssm/gdn.py +++ b/fast_llm/layers/ssm/gdn.py @@ -302,6 +302,7 @@ def __init__( self.dt_bias: ParameterMeta = self._config.dt_bias_weight.get_parameter( (self._value_heads_dim,), default_initialization=init_ones_, + weight_decay=False, lr_scale=self._lr_scale, peft=self._peft, ) @@ -310,6 +311,7 @@ def __init__( default_initialization=LambdaInitializer( lambda _, tensor, generator: tensor.uniform_(0, 16, generator=generator).log_() ), + weight_decay=False, lr_scale=self._lr_scale, peft=self._peft, ) diff --git a/fast_llm/layers/ssm/kda.py b/fast_llm/layers/ssm/kda.py index c59bfe036..dd4f169f4 100644 --- a/fast_llm/layers/ssm/kda.py +++ b/fast_llm/layers/ssm/kda.py @@ -343,6 +343,7 @@ def __init__( self.dt_bias: ParameterMeta = self._config.dt_bias_weight.get_parameter( (self._projection_dim,), default_initialization=init_ones_, + weight_decay=False, lr_scale=self._lr_scale, peft=self._peft, ) @@ -351,6 +352,7 @@ def __init__( default_initialization=LambdaInitializer( lambda _, tensor, generator: tensor.uniform_(1, 16, generator=generator).log_() ), + weight_decay=False, lr_scale=self._lr_scale, peft=self._peft, ) diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index fbf55e3b2..242b7cebb 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -248,7 +248,7 @@ def __init__( tensor_name: str = "", dims: tuple[TensorDim, ...], init_method: "Initialization | typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None] | None" = None, - weight_decay: bool = True, + weight_decay: float | bool = True, # Pass a list to split the parameter in contiguous (dim=0) chunks of equal size for optimization. lr_scale: float | None | tuple[float | None, ...] = None, requires_grad: bool = True, @@ -285,7 +285,7 @@ def __new__( tensor_name: str = "", dims: tuple[TensorDim, ...], init_method: "Initializer | typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None] | None", - weight_decay: bool = True, + weight_decay: float | bool = True, lr_scale: float | None | tuple[float | None, ...] = None, allow_sequence_tensor_parallel: bool = True, allow_no_grad: bool = False,