fast_llm/layers/language_model/config.py (6 additions, 0 deletions)
@@ -197,6 +197,12 @@ def _validate(self) -> None:
         Assert.eq(len(self.prediction_loss_coefficient), self.prediction_heads)
         for coeff in self.prediction_loss_coefficient:
             Assert.geq(coeff, 0)
+        if self.transformer.per_layer_lr_scale is not None:
+            # -1 because the first prediction head's transformer layer is accounted for in num_layers
+            # +1 because the layer index starts at 1
+            Assert.eq(
+                len(self.transformer.per_layer_lr_scale), self.transformer.num_layers + self.prediction_heads - 1 + 1
+            )

     def setup_tensor_space(self, tensor_space: TensorSpace) -> None:
         self.transformer.setup_tensor_space(tensor_space)
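For intuition, here is a minimal sketch (not part of the PR; the numbers are assumed for illustration) of what the new length check enforces:

# Hypothetical values, chosen only to make the arithmetic concrete.
num_layers = 4        # shared transformer layers; includes the first head's layer
prediction_heads = 2  # multi-token prediction heads

# One extra transformer layer per additional head (-1 for the first head,
# already counted in num_layers), plus one leading slot because layer
# indices start at 1 (+1):
expected_len = num_layers + prediction_heads - 1 + 1  # == 6

per_layer_lr_scale = [1.0, 1.0, 1.0, 1.0, 1.0, 0.5]  # must have exactly 6 entries
assert len(per_layer_lr_scale) == expected_len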
fast_llm/layers/transformer/attention.py (1 addition, 1 deletion)
@@ -82,7 +82,7 @@ def __init__(
         super().__init__()
         self._config = config
         self._tensor_space = tensor_space
-        Assert.in_range_incl(layer_index, 1, max(self._config.num_layers, 1))
+        # Assert.in_range_incl(layer_index, 1, max(self._config.num_layers, 1))
         self._layer_index = layer_index
         self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel
         self._debug_transformer = self._config.debug_transformer
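A short sketch of why the old range check was disabled; this rationale is inferred from the model.py change further down (the values here are assumed, not from the PR):

# With multi-token prediction, extra heads receive layer_index = num_layers + i,
# which falls outside [1, num_layers] for every head after the first.
num_layers, prediction_heads = 4, 3
head_indices = [max(num_layers + i, 1) for i in range(prediction_heads)]  # [4, 5, 6]
# Indices 5 and 6 exceed num_layers, so the old
# Assert.in_range_incl(layer_index, 1, max(num_layers, 1)) would fail for them.
assert any(index > num_layers for index in head_indices)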
fast_llm/layers/transformer/mixture_of_experts.py (1 addition, 1 deletion)
@@ -44,7 +44,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s
         Assert.gt(config.num_experts, 1)
         # TODO: Implement?
         assert not config.add_linear_biases, "Biases not supported for MoE."
-        super().__init__(config, tensor_space, name)
+        super().__init__(config, tensor_space, name, layer_index)
         self._config = config
         self._tensor_space = tensor_space
         self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory
fast_llm/layers/transformer/mlp.py (1 addition, 1 deletion)
@@ -71,7 +71,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s
 class MLP(MLPBase):
     def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0):
         Assert.eq(config.num_experts, 1)
-        super().__init__(config, tensor_space, name)
+        super().__init__(config, tensor_space, name, layer_index)

     def forward(
         self,
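This change and the mixture-of-experts one above fix the same pass-through bug: the subclass accepted layer_index but never forwarded it, so the base class kept its default of 0. A self-contained sketch of the failure mode (class names hypothetical, not the Fast-LLM classes):

class Base:
    def __init__(self, name: str = "mlp", layer_index: int = 0):
        # layer_index is what downstream code (e.g. a per-layer lr scale lookup) sees.
        self._layer_index = layer_index

class Child(Base):
    def __init__(self, name: str = "mlp", layer_index: int = 0):
        super().__init__(name)  # bug: layer_index is dropped, Base keeps 0
        # fixed version: super().__init__(name, layer_index)

assert Child(layer_index=7)._layer_index == 0  # demonstrates the dropped argument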
fast_llm/models/gpt/model.py (1 addition, 1 deletion)
@@ -72,7 +72,7 @@ def get_output_layers(self) -> list[Layer]:
                 self._config.transformer,
                 self._tensor_space,
                 # TODO MTP: which index?
-                layer_index=max(self._config.transformer.num_layers, 1),
+                layer_index=max(self._config.transformer.num_layers + i, 1),
                 # The last layer only returns the transformer output.
                 # The previous layers return a stack of shared_hidden and transformer_output.
                 return_input=i < self._config.prediction_heads - 1,

Review thread on the layer_index change:

Collaborator: This will have unintended consequences on the initialization scale.

Contributor Author: This only affects the prediction heads for i > 0 (and thus not next-token prediction).

Collaborator: Yes, but the layer index is used elsewhere. It looks like it's only used in the backup attention regularization though, so it doesn't matter much: https://github.com/ServiceNow/Fast-LLM/blob/main/fast_llm/layers/transformer/attention.py#L181. I got mixed up with num_layers, which does matter for initialization.
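Putting the two changed arguments together, a sketch (loop structure and values assumed, not copied from the PR) of how the output heads are wired: head i gets layer_index = num_layers + i, and every head except the last forwards the shared hidden state to its successor:

# Hypothetical values for illustration.
num_layers, prediction_heads = 12, 3
for i in range(prediction_heads):
    layer_index = max(num_layers + i, 1)     # 12, 13, 14
    return_input = i < prediction_heads - 1  # True, True, False
    # Heads 0 and 1 return (shared_hidden, transformer_output) for the next head;
    # the last head returns only the transformer output.
    print(i, layer_index, return_input)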