From 395f263045ba67ddb29f6a9d032b6985a765d9d0 Mon Sep 17 00:00:00 2001
From: Daniel Johnson
Date: Sun, 8 Jun 2025 23:35:14 -0700
Subject: [PATCH] Add missing attributes to hugging face conversion functions

Fixes https://github.com/google-deepmind/penzai/issues/115
Fixes https://github.com/google-deepmind/penzai/issues/112

Co-authored-by: Eric Alt
---
 .../models/transformer/variants/gpt_neox.py   | 11 +++--
 penzai/models/transformer/variants/llama.py   | 14 +++++--
 .../transformer/variants/llamalike_common.py  | 16 +++++--
 penzai/models/transformer/variants/mistral.py | 14 +++++--
 tests/models/transformer_consistency_test.py  | 42 ++++++++++++++++++-
 5 files changed, 84 insertions(+), 13 deletions(-)

diff --git a/penzai/models/transformer/variants/gpt_neox.py b/penzai/models/transformer/variants/gpt_neox.py
index e408ea8..62a67db 100644
--- a/penzai/models/transformer/variants/gpt_neox.py
+++ b/penzai/models/transformer/variants/gpt_neox.py
@@ -405,13 +405,18 @@ def gpt_neox_from_huggingface_model(
       "rotary_pct",
       "vocab_size",
       # Ignored by conversion:
-      "max_position_embeddings",
-      "torch_dtype",
+      "_attn_implementation_autoset",
+      "_name_or_path",
       "architectures",
+      "attention_probs_dropout_prob",
       "bos_token_id",
       "eos_token_id",
-      "_attn_implementation_autoset",
       "head_dim",
+      "hidden_dropout_prob",
+      "is_decoder",
+      "max_position_embeddings",
+      "torch_dtype",
+      "type_vocab_size",
   }
   bad_attributes = {}
   for k, v in hf_config_attributes.items():
diff --git a/penzai/models/transformer/variants/llama.py b/penzai/models/transformer/variants/llama.py
index 1183c9d..97b49a2 100644
--- a/penzai/models/transformer/variants/llama.py
+++ b/penzai/models/transformer/variants/llama.py
@@ -66,6 +66,7 @@ def llama_from_huggingface_model(
   reference_attributes = transformers.LlamaConfig().to_dict()
   handled_or_ignored_attributes = {
       # Handled during conversion:
+      "hidden_act",
       "hidden_size",
       "intermediate_size",
       "num_attention_heads",
@@ -75,13 +76,20 @@ def llama_from_huggingface_model(
       "rope_theta",
       "vocab_size",
       # Ignored by conversion:
-      "max_position_embeddings",
-      "torch_dtype",
+      "_attn_implementation_autoset",
+      "_name_or_path",
       "architectures",
+      "attention_probs_dropout_prob",
       "bos_token_id",
       "eos_token_id",
-      "_attn_implementation_autoset",
       "head_dim",
+      "hidden_dropout_prob",
+      "is_decoder",
+      "max_position_embeddings",
+      "pad_token_id",
+      "torch_dtype",
+      "type_vocab_size",
+      "use_cache",
   }
   bad_attributes = {}
   for k, v in hf_config_attributes.items():
diff --git a/penzai/models/transformer/variants/llamalike_common.py b/penzai/models/transformer/variants/llamalike_common.py
index b1247bc..321cd13 100644
--- a/penzai/models/transformer/variants/llamalike_common.py
+++ b/penzai/models/transformer/variants/llamalike_common.py
@@ -118,7 +118,7 @@ class LlamalikeTransformerConfig:
   mlp_hidden_dim: int
   num_decoder_blocks: int
   vocab_size: int
-  mlp_variant: Literal["geglu_approx", "swiglu"]
+  mlp_variant: Literal["geglu_exact", "geglu_approx", "swiglu"]
   tie_embedder_and_logits: bool
   rope_wavelength: float = 10_000
   rms_norm_eps: float = 1e-6
@@ -157,7 +157,9 @@ def build_llamalike_feedforward(
   Returns:
     An instance of TransformerFeedForward containing the GELU MLP blocks.
   """
-  if config.mlp_variant == "geglu_approx":
+  if config.mlp_variant == "geglu_exact":
+    act_fn = functools.partial(jax.nn.gelu, approximate=False)
+  elif config.mlp_variant == "geglu_approx":
     # Approximate is already the default in JAX, but we specify it explicitly
     # because defaults differ between JAX and PyTorch.
     act_fn = functools.partial(jax.nn.gelu, approximate=True)
@@ -641,6 +643,14 @@ def llamalike_from_huggingface_model(
   else:
     activation_dtype = param_dtype
 
+  # Map HuggingFace hidden_act to Penzai mlp_variant
+  hidden_act_to_mlp_variant = {
+      "silu": "swiglu",
+      "gelu": "geglu_exact",
+      "gelu_new": "geglu_approx",
+  }
+  mlp_variant = hidden_act_to_mlp_variant[hf_config.hidden_act]
+
   pz_config = LlamalikeTransformerConfig(
       num_kv_heads=num_kv_heads,
       query_head_multiplier=query_head_multiplier,
@@ -649,7 +659,7 @@ def llamalike_from_huggingface_model(
       mlp_hidden_dim=hf_config.intermediate_size,
       num_decoder_blocks=hf_config.num_hidden_layers,
       vocab_size=hf_config.vocab_size,
-      mlp_variant="swiglu",
+      mlp_variant=mlp_variant,
       rope_wavelength=hf_config.rope_theta,
       tie_embedder_and_logits=False,
       attention_type=attention_type,
diff --git a/penzai/models/transformer/variants/mistral.py b/penzai/models/transformer/variants/mistral.py
index c543b84..72d61ab 100644
--- a/penzai/models/transformer/variants/mistral.py
+++ b/penzai/models/transformer/variants/mistral.py
@@ -71,6 +71,7 @@ def mistral_from_huggingface_model(
   reference_attributes = transformers.MistralConfig().to_dict()
   handled_or_ignored_attributes = {
       # Handled during conversion:
+      "hidden_act",
       "hidden_size",
       "intermediate_size",
       "num_attention_heads",
@@ -81,11 +82,18 @@ def mistral_from_huggingface_model(
       "vocab_size",
       "sliding_window",
       # Ignored by conversion:
-      "max_position_embeddings",
-      "torch_dtype",
-      "architectures",
       "_attn_implementation_autoset",
+      "_name_or_path",
+      "architectures",
+      "attention_probs_dropout_prob",
       "head_dim",
+      "hidden_dropout_prob",
+      "is_decoder",
+      "max_position_embeddings",
+      "pad_token_id",
+      "torch_dtype",
+      "type_vocab_size",
+      "use_cache",
   }
   bad_attributes = {}
   for k, v in hf_config_attributes.items():
diff --git a/tests/models/transformer_consistency_test.py b/tests/models/transformer_consistency_test.py
index cc6a166..9f9c710 100644
--- a/tests/models/transformer_consistency_test.py
+++ b/tests/models/transformer_consistency_test.py
@@ -36,12 +36,24 @@ class TransformerConsistencyTest(parameterized.TestCase):
   )
   def test_llama_consistency(self, num_attention_heads, num_key_value_heads):
     cfg = transformers.LlamaConfig(
+        # Adjusted architecture parameters for a smaller version of Llama.
         vocab_size=11,
         hidden_size=64,
         intermediate_size=256,
         num_hidden_layers=3,
         num_attention_heads=num_attention_heads,
         num_key_value_heads=num_key_value_heads,
+        # Extra parameters that are set when loading official models from
+        # HuggingFace but aren't set by default in LlamaConfig.
+        max_position_embeddings=8192,
+        rms_norm_eps=1e-05,
+        rope_theta=500000.0,
+        torch_dtype="bfloat16",
+        architectures=["LlamaForCausalLM"],
+        bos_token_id=128000,
+        eos_token_id=128001,
+        _name_or_path="meta-llama/Meta-Llama-3-8B",
+        _attn_implementation_autoset=True,
     )
 
     torch.manual_seed(0)
@@ -73,15 +85,35 @@ def test_llama_consistency(self, num_attention_heads, num_key_value_heads):
       dict(testcase_name="full", num_attention_heads=4, num_key_value_heads=4),
       dict(testcase_name="mqa", num_attention_heads=4, num_key_value_heads=1),
       dict(testcase_name="gqa", num_attention_heads=4, num_key_value_heads=2),
+      dict(
+          testcase_name="act_gelu",
+          num_attention_heads=4,
+          num_key_value_heads=4,
+          hidden_act="gelu",
+      ),
   )
-  def test_mistral_consistency(self, num_attention_heads, num_key_value_heads):
+  def test_mistral_consistency(
+      self, num_attention_heads, num_key_value_heads, hidden_act="silu"
+  ):
     cfg = transformers.MistralConfig(
+        # Adjusted architecture parameters for a smaller version of Mistral.
         vocab_size=11,
         hidden_size=64,
         intermediate_size=256,
         num_hidden_layers=3,
         num_attention_heads=num_attention_heads,
         num_key_value_heads=num_key_value_heads,
+        hidden_act=hidden_act,
+        # Extra parameters that are set when loading official models from
+        # HuggingFace but aren't set by default in MistralConfig.
+        max_position_embeddings=32768,
+        sliding_window=None,
+        rms_norm_eps=1e-05,
+        rope_theta=1000000.0,
+        torch_dtype="bfloat16",
+        architectures=["MistralForCausalLM"],
+        _name_or_path="fake_org/fake-Mistral-version",
+        _attn_implementation_autoset=True,
     )
 
     torch.manual_seed(0)
@@ -110,11 +142,19 @@ def test_mistral_consistency(self, num_attention_heads, num_key_value_heads):
 
   def test_gpt_neox_consistency(self):
     cfg = transformers.GPTNeoXConfig(
+        # Adjusted architecture parameters for a smaller version of GPT-NeoX.
         vocab_size=11,
         hidden_size=64,
         intermediate_size=256,
         num_hidden_layers=3,
         num_attention_heads=4,
+        # Extra parameters that are set when loading official models from
+        # HuggingFace but aren't set by default in GPTNeoXConfig.
+        torch_dtype="float16",
+        architectures=["GPTNeoXForCausalLM"],
+        eos_token_id=0,
+        _name_or_path="fake_org/fake-GPTNeoX-version",
+        _attn_implementation_autoset=True,
    )
 
    torch.manual_seed(0)
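
Usage sketch: a minimal way to exercise the new hidden_act handling after applying this patch, assuming llama_from_huggingface_model accepts the instantiated HuggingFace model as its only required argument; the tiny config values below are illustrative only.

    import transformers
    from penzai.models.transformer.variants import llama

    # Build a small HuggingFace Llama model whose MLP uses exact (non-approximate) GELU.
    cfg = transformers.LlamaConfig(
        vocab_size=11,
        hidden_size=64,
        intermediate_size=256,
        num_hidden_layers=3,
        num_attention_heads=4,
        num_key_value_heads=4,
        hidden_act="gelu",  # with this patch, mapped to mlp_variant="geglu_exact"
    )
    hf_model = transformers.LlamaForCausalLM(cfg)

    # Convert to a Penzai model; previously the conversion always assumed swiglu.
    pz_model = llama.llama_from_huggingface_model(hf_model)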