57 changes: 57 additions & 0 deletions mergekit/_data/architectures/kormo.json
@@ -0,0 +1,57 @@
{
"model_type": "kormo",
"architectures": [
"KORMoForCausalLM"
],
"pre_weights": [
{
"name": "model.embed_tokens.weight",
"is_embed": true
}
],
"post_weights": [
{
"name": "model.norm.weight"
},
{
"name": "lm_head.weight",
"is_embed": true,
"optional": true,
"tied_names": [
"model.embed_tokens.weight"
]
}
],
"num_layers_config_key": "num_hidden_layers",
"layer_templates": {
"weights": [
{
"name": "model.layers.${layer_index}.pre_attention_layernorm.weight"
},
{
"name": "model.layers.${layer_index}.self_attn.q_proj.weight"
},
{
"name": "model.layers.${layer_index}.self_attn.k_proj.weight"
},
{
"name": "model.layers.${layer_index}.self_attn.v_proj.weight"
},
{
"name": "model.layers.${layer_index}.self_attn.o_proj.weight"
},
{
"name": "model.layers.${layer_index}.pre_mlp_layernorm.weight"
},
{
"name": "model.layers.${layer_index}.mlp.gate_proj.weight"
},
{
"name": "model.layers.${layer_index}.mlp.up_proj.weight"
},
{
"name": "model.layers.${layer_index}.mlp.down_proj.weight"
}
]
}
}
48 changes: 48 additions & 0 deletions mergekit/_data/architectures/kormo_moe.json
@@ -0,0 +1,48 @@
{
"model_type": "kormo_moe",
"architectures": [
"KORMoMoeForCausalLM"
],
"pre_weights": [
{
"name": "model.embed_tokens.weight",
"is_embed": true
}
],
"post_weights": [
{
"name": "model.norm.weight"
},
{
"name": "lm_head.weight",
"is_embed": true,
"optional": true,
"tied_names": [
"model.embed_tokens.weight"
]
}
],
"num_layers_config_key": "num_hidden_layers",
"layer_templates": {
"weights": [
{
"name": "model.layers.${layer_index}.pre_attention_layernorm.weight"
},
{
"name": "model.layers.${layer_index}.self_attn.q_proj.weight"
},
{
"name": "model.layers.${layer_index}.self_attn.k_proj.weight"
},
{
"name": "model.layers.${layer_index}.self_attn.v_proj.weight"
},
{
"name": "model.layers.${layer_index}.self_attn.o_proj.weight"
},
{
"name": "model.layers.${layer_index}.pre_mlp_layernorm.weight"
}
]
}
}
9 changes: 9 additions & 0 deletions mergekit/architecture/__init__.py
@@ -20,6 +20,7 @@
from mergekit.architecture.moe_defs import (
MixtralModuleArchitecture,
Qwen3MoeModuleArchitecture,
KORMoMoeModuleArchitecture,
)
from mergekit.options import MergeOptions

@@ -34,6 +35,7 @@
def arch_info_for_config(config: PretrainedConfig) -> Optional[ModelArchitecture]:
if len(config.architectures) != 1:
raise RuntimeError("More than one architecture in config?")

arch_name = config.architectures[0]

if arch_name == MixtralModuleArchitecture.ARCHITECTURE_NAME:
@@ -50,6 +52,13 @@ def arch_info_for_config(config: PretrainedConfig) -> Optional[ModelArchitecture
architectures=[arch_name],
model_type="qwen3_moe",
)
elif arch_name == KORMoMoeModuleArchitecture.ARCHITECTURE_NAME:  # added
module = KORMoMoeModuleArchitecture.from_config(config)
return ModelArchitecture(
modules={"default": ModuleDefinition(architecture=module)},
architectures=[arch_name],
model_type="kormo_moe",
)
elif arch_name in NAME_TO_ARCH:
candidates = list(NAME_TO_ARCH[arch_name])
if len(candidates) == 1:
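For context, a minimal sketch of how the new dispatch branch above would be exercised; the checkpoint path is hypothetical, and loading a custom architecture via AutoConfig assumes the checkpoint ships the usual trust_remote_code config registration.

from transformers import AutoConfig
from mergekit.architecture import arch_info_for_config

# Hypothetical local checkpoint containing a KORMoMoeForCausalLM config.
cfg = AutoConfig.from_pretrained("path/to/kormo-moe", trust_remote_code=True)
arch_info = arch_info_for_config(cfg)
print(arch_info.model_type)  # expected to print "kormo_moe"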
79 changes: 78 additions & 1 deletion mergekit/architecture/moe_defs.py
@@ -85,19 +85,96 @@ def num_layers_config_key(self) -> str:
def layer_weights(
self, index: int, config: PretrainedConfig
) -> Optional[List[WeightInfo]]:
num_experts = self.num_experts
prefix = f"model.layers.{index}"
tensor_names = []
for expert_idx in range(self.num_experts):

# Add expert weights
for expert_idx in range(num_experts):
Comment on lines +91 to +93

The parameter order in the expert loops is inconsistent between implementations: Qwen3MoeModuleArchitecture uses up_proj, gate_proj, down_proj, while KORMoMoeModuleArchitecture uses gate_proj, up_proj, down_proj. The order should be standardized so weights are processed correctly across the different model architectures; consider aligning both implementations.

Spotted by Graphite Agent

for param in ("up_proj", "gate_proj", "down_proj"):
tensor_names.append(
prefix + f".mlp.experts.{expert_idx}.{param}.weight"
)

# Add shared expert weights - this part is important!
for param in ("up_proj", "gate_proj", "down_proj"):
tensor_names.append(
prefix + f".mlp.shared_expert.{param}.weight"
)

# Add gate weights
tensor_names.append(prefix + ".mlp.gate.weight")
tensor_names.append(prefix + ".mlp.shared_expert_gate.weight")
Bug: Layer Weights Naming Conflict

The Qwen3MoeModuleArchitecture.layer_weights method unconditionally adds shared expert and shared expert gate weight names. This conflicts with write_model's conditional writing of these weights, potentially causing them to be missing. Additionally, the declared names clash with write_model's path transformation, leading to malformed weight paths.
res = []
for name in tensor_names:
res.append(WeightInfo(name=name))

# Add the existing Qwen3 weights, excluding the MLP weights
for weight_info in QWEN3_MODULE_ARCH.layer_weights(index, config):
if ".mlp." in weight_info.name:
continue
res.append(weight_info)

return res

# Add to the import section at the top of the file
Bug: Mixed Language Comments Cause Codebase Confusion

Korean language development comments were accidentally committed, creating inconsistency with the English codebase and potentially confusing international contributors.

KORMO_INFO = NAME_TO_ARCH["KORMoForCausalLM"][0]
KORMO_MODULE_ARCH = KORMO_INFO.modules["default"].architecture


class KORMoMoeModuleArchitecture(ModuleArchitecture, BaseModel):
ARCHITECTURE_NAME: ClassVar[str] = "KORMoMoeForCausalLM"
num_experts: int

def name(self) -> str:
return "kormo_moe"

@classmethod
def from_config(cls, config: PretrainedConfig):
return KORMoMoeModuleArchitecture(num_experts=config.num_experts)

def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
return KORMO_MODULE_ARCH.pre_weights(config)

def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
return KORMO_MODULE_ARCH.post_weights(config)

def num_layers_config_key(self) -> str:
return KORMO_MODULE_ARCH.num_layers_config_key()

def layer_weights(
self, index: int, config: PretrainedConfig
) -> Optional[List[WeightInfo]]:
num_experts = self.num_experts
prefix = f"model.layers.{index}"
tensor_names = []

# Add expert weights
for expert_idx in range(num_experts):
for param in ("gate_proj", "up_proj", "down_proj"):
tensor_names.append(
prefix + f".mlp.experts.{expert_idx}.{param}.weight"
)

# Add shared expert weights
for param in ("gate_proj", "up_proj", "down_proj"):
tensor_names.append(
prefix + f".mlp.shared_expert.{param}.weight"
)

# Add gate weights
tensor_names.append(prefix + ".mlp.gate.weight")
tensor_names.append(prefix + ".mlp.shared_expert_gate.weight")

res = []
for name in tensor_names:
res.append(WeightInfo(name=name))

# Add the existing KORMo weights, excluding the MLP weights
for weight_info in KORMO_MODULE_ARCH.layer_weights(index, config):
if ".mlp." in weight_info.name:
continue
res.append(weight_info)

return res
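Following up on the review comments on this file, here is a minimal sketch (not part of this PR) of one way to address both points: a single helper that either MoE module architecture could call from layer_weights, so the expert parameter order lives in exactly one place and the shared-expert tensors are only declared when the config actually defines a shared expert. The shared_expert_intermediate_size attribute checked here is an assumption based on the KORMoMoeConfig introduced later in this PR.

from typing import List, Tuple


def moe_layer_tensor_names(
    prefix: str,
    num_experts: int,
    config,
    param_order: Tuple[str, ...] = ("gate_proj", "up_proj", "down_proj"),
) -> List[str]:
    names = []
    # Per-expert MLP projections, always emitted in the same canonical order.
    for expert_idx in range(num_experts):
        for param in param_order:
            names.append(f"{prefix}.mlp.experts.{expert_idx}.{param}.weight")
    # Shared-expert tensors only when the config declares a shared expert
    # (assumed attribute; see KORMoMoeConfig below).
    if getattr(config, "shared_expert_intermediate_size", None):
        for param in param_order:
            names.append(f"{prefix}.mlp.shared_expert.{param}.weight")
        names.append(f"{prefix}.mlp.shared_expert_gate.weight")
    # The router gate is always present.
    names.append(f"{prefix}.mlp.gate.weight")
    return names

Both layer_weights implementations could then simply wrap these names in WeightInfo objects, which keeps the ordering decision and the shared-expert condition out of the per-architecture code.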
19 changes: 18 additions & 1 deletion mergekit/moe/__init__.py
@@ -6,14 +6,31 @@

ALL_OUTPUT_ARCHITECTURES: List[MoEOutputArchitecture] = [MixtralMoE(), DeepseekMoE()]

# Add Qwen3MoE first
try:
from mergekit.moe.qwen3 import Qwen3MoE
except ImportError:
pass
else:
ALL_OUTPUT_ARCHITECTURES.append(Qwen3MoE())

# Add QwenMoE afterwards (as a fallback)
try:
from mergekit.moe.qwen import QwenMoE
except ImportError:
pass
else:
ALL_OUTPUT_ARCHITECTURES.append(QwenMoE())

# Add KORMo MoE
try:
from mergekit.moe.kormo import KORMoMoE
except ImportError:
pass
else:
ALL_OUTPUT_ARCHITECTURES.append(KORMoMoE())

__all__ = [
"ALL_OUTPUT_ARCHITECTURES",
"MoEOutputArchitecture",
]
]
86 changes: 86 additions & 0 deletions mergekit/moe/_architectures/configuration_kormo_moe.py
@@ -0,0 +1,86 @@
# <saved_model_path>/configuration_kormo_moe.py

from transformers import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation


class KORMoMoeConfig(PretrainedConfig):
model_type = "kormo_moe"
keys_to_ignore_at_inference = ["past_key_values"]

def __init__(
self,
vocab_size=112576,
hidden_size=6144,
intermediate_size=21504,
num_hidden_layers=48,
num_attention_heads=40,
num_key_value_heads=8,
hidden_act="silu",
max_position_embeddings=131072,
initializer_range=0.02,
rms_norm_eps=1e-05,
use_cache=True,
pad_token_id=None,
bos_token_id=0,
eos_token_id=1,
pretraining_tp=1,
tie_word_embeddings=False,
rope_theta=500000.0,
attention_bias=False,
attention_dropout=0.0,
rope_scaling=None,
mlp_bias=False,
head_dim=128,
# MoE specific
num_experts=2,
num_experts_per_tok=2,
moe_intermediate_size=None,
shared_expert_intermediate_size=None,
norm_topk_prob=True,
decoder_sparse_step=1,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads

if num_key_value_heads is None:
num_key_value_heads = num_attention_heads

self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.mlp_bias = mlp_bias
self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
self.mask_type = None

# MoE specific
self.num_experts = num_experts
self.num_experts_per_tok = num_experts_per_tok
self.moe_intermediate_size = moe_intermediate_size if moe_intermediate_size is not None else intermediate_size
self.shared_expert_intermediate_size = shared_expert_intermediate_size
self.norm_topk_prob = norm_topk_prob
self.decoder_sparse_step = decoder_sparse_step

if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)

super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
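For reference, a minimal usage sketch of the config class above; the values and output directory are assumptions for illustration, not taken from this PR.

from configuration_kormo_moe import KORMoMoeConfig

# Build a 2-expert KORMo MoE config; unspecified fields fall back to the
# dense KORMo defaults defined above.
config = KORMoMoeConfig(
    num_experts=2,
    num_experts_per_tok=2,
    shared_expert_intermediate_size=21504,  # assumed: reuse the dense FFN size
)

# Standard PretrainedConfig API: writes config.json into the target directory.
config.save_pretrained("./kormo-moe-merged")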