diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py
index 61a04a752..d6e9ff98b 100644
--- a/lightllm/common/basemodel/basemodel.py
+++ b/lightllm/common/basemodel/basemodel.py
@@ -75,6 +75,14 @@ def _verify_params(self):
         assert self.load_way == "HF", "only support HF format weights"
         assert self.config["num_key_value_heads"] % self.world_size_ == 0
         return
+
+    def load_weights_from_dict(self, weight_dict):
+        load_hf_weights(
+            "fp16",
+            weight_dir=self.weight_dir_,
+            pre_post_layer=self.pre_post_weight,
+            transformer_layer_list=self.trans_layers_weight,
+            weight_dict=weight_dict)
 
     def _init_weights(self):
         self.pre_post_weight = self.pre_and_post_weight_class(self.tp_rank_, self.world_size_, torch.float16, network_config=self.config, mode=self.mode)
@@ -88,8 +96,9 @@ def _init_weights(self):
                        pre_post_layer=self.pre_post_weight,
                        transformer_layer_list=self.trans_layers_weight,
                        weight_dict=self.weight_dict)
-        self.pre_post_weight.verify_load()
-        [weight.verify_load() for weight in self.trans_layers_weight]
+        if not self.weight_dict == {}:
+            self.pre_post_weight.verify_load()
+            [weight.verify_load() for weight in self.trans_layers_weight]
         return
 
     def _init_mem_manager(self):
diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py
index 29583a54c..837a0f9a9 100644
--- a/lightllm/models/llama/model.py
+++ b/lightllm/models/llama/model.py
@@ -95,8 +95,9 @@ def _init_weights(self):
                        weight_dict=self.weight_dict,
                        prefix='model.layers.',
                        num_layer=self.config["n_layer"])
-        self.pre_post_weight.verify_load()
-        [weight.verify_load() for weight in self.trans_layers_weight]
+        if not self.weight_dict == {}:
+            self.pre_post_weight.verify_load()
+            [weight.verify_load() for weight in self.trans_layers_weight]
         return
 
     def _init_to_get_rotary(self, default_base=10000):
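For reference, a minimal usage sketch of the new load_weights_from_dict entry point, not part of the diff: the helper name push_state_dict_into_model, the model variable, and the checkpoint path are assumptions, and it presumes the model was constructed with weight_dict={} so that verify_load() was skipped in _init_weights() per the change above.

import torch

def push_state_dict_into_model(model, checkpoint_path):
    # Hypothetical helper (assumption, not from the diff): `model` is an
    # already-constructed model instance from this repo whose __init__ received
    # weight_dict={}, so _init_weights() deferred verify_load().
    weight_dict = torch.load(checkpoint_path, map_location="cpu")
    # New method added by this diff: feeds the in-memory tensors through
    # load_hf_weights instead of re-reading weight files from weight_dir_.
    model.load_weights_from_dict(weight_dict)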