diff --git a/pyvene/models/modeling_utils.py b/pyvene/models/modeling_utils.py index 3edd6ff6..6a374178 100644 --- a/pyvene/models/modeling_utils.py +++ b/pyvene/models/modeling_utils.py @@ -227,7 +227,8 @@ def output_to_subcomponent(output, component, model_type, model_config): :param model_config: Hugging Face Model Config """ subcomponent = output - if component in type_to_module_mapping[model_type]: + if model_type in type_to_module_mapping and \ + component in type_to_module_mapping[model_type]: split_last_dim_by = type_to_module_mapping[model_type][component][2:] if len(split_last_dim_by) != 0 and len(split_last_dim_by) > 2: raise ValueError(f"Unsupported {split_last_dim_by}.")