@@ -32,6 +32,8 @@ def euc(query, key, config, act_mask=None, infer=False):
3232
3333 act_fn = ACT2FN [config .hidden_act ]
3434 l2_norm = torch .norm (act_fn (key ) - act_fn (query ), dim = - 1 )
35+ if l2_norm .dim () == 1 :
36+ l2_norm = l2_norm .unsqueeze (0 )
3537 if infer and l2_norm .size (1 ) > 100 :
3638 topk = torch .topk (l2_norm , k = 1 , largest = True )
3739 return topk .values .mean ()
@@ -74,7 +76,20 @@ def __init__(self, config, model, device):
7476 self .layer_name = self .layer .rsplit ("." , 1 )[- 1 ]
7577 adapter_layer = getattr (self .edit_module , self .layer_name )
7678
77- if type (adapter_layer ) is not WISEAdapter :
79+ # if the condition below is True, then it is single-edit
80+ if not config .sequential_edit :
81+ # if type(adapter_layer) is not WISEAdapter:
82+ # 如果 adapter_layer 已经是 WISEAdapter,提取其原始层
83+ if type (adapter_layer ) is WISEAdapter :
84+ # 使用 original_layer 作为基础层(这是保存的原始层副本)
85+ base_layer = adapter_layer .original_layer
86+ else :
87+ base_layer = adapter_layer
88+
89+ setattr (self .edit_module , self .layer_name , WISEAdapter (config , base_layer , transpose = transpose ))
90+ self .original_layer = copy .deepcopy (base_layer )
91+ print (f"New weights successfully inserted into { layer } " )
92+ elif type (adapter_layer ) is not WISEAdapter :
7893 setattr (self .edit_module , self .layer_name , WISEAdapter (config , adapter_layer , transpose = transpose ))
7994 self .original_layer = copy .deepcopy (adapter_layer )
8095 print (f"New weights successfully inserted into { layer } " )
@@ -84,16 +99,27 @@ def __init__(self, config, model, device):
8499 gc .collect ()
85100
86101 # Forward
87- def __call__ (self , ** kwargs ):
102+ def __call__ (self , * args , * *kwargs ):
88103 if not self .config .retrieve :
89- if hasattr (self .get_adapter_layer (), 'editing' ) and not self .get_adapter_layer ().editing :
90- # final merge
91- if not self .get_adapter_layer ().original_layer .weight .equal (self .get_adapter_layer ().new_weight ) and self .get_adapter_layer ().editing_total_cnt >= self .config .save_freq :
92- self .get_adapter_layer ().memory_weight .append (self .get_adapter_layer ().new_weight )
93- if len (self .get_adapter_layer ().memory_weight ) > 0 and self .get_adapter_layer ().editing_total_cnt >= self .config .save_freq :
94- print ('length of memory is ' , len (self .get_adapter_layer ().memory_weight ), '!!!!!!' )
95- self .get_adapter_layer ().merge_weight ()
96- return self .model (** kwargs )
104+ adapter = self .get_adapter_layer ()
105+ if hasattr (adapter , 'editing' ) and not adapter .editing :
106+ if (not adapter .original_layer .weight .equal (adapter .new_weight )
107+ and adapter .editing_total_cnt >= self .config .save_freq ):
108+ adapter .memory_weight .append (adapter .new_weight )
109+
110+ if len (adapter .memory_weight ) > 0 and adapter .editing_total_cnt >= self .config .save_freq :
111+ print ('length of memory is ' , len (adapter .memory_weight ), '!!!!!!' )
112+ adapter .merge_weight ()
113+ # 1. 如果用户传入 model(batch)
114+ if len (args ) == 1 and isinstance (args [0 ], dict ):
115+ return self .model (args [0 ])
116+ # 2. 如果用户传入 model(batch=batch)
117+ elif "batch" in kwargs and isinstance (kwargs ["batch" ], dict ):
118+ batch = kwargs .pop ("batch" )
119+ return self .model (** batch , ** kwargs )
120+ # 3. 普通 HuggingFace 风格,如 model(input_ids=..., pixel_values=...)
121+ else :
122+ return self .model (** kwargs )
97123
98124 def reset_layer (self ):
99125 layer = getattr (self .edit_module , self .layer_name )
@@ -257,7 +283,14 @@ def _cal_activation_loss(self, original_layer_output, new_weight_layer_output, c
257283 else :
258284 k = 1
259285 total_loss = []
286+ if self .config .model_name == "blip2" :
287+ original_layer_output = original_layer_output .reshape (2 , - 1 , original_layer_output .size (- 1 ))
288+ new_weight_layer_output = new_weight_layer_output .reshape (2 , - 1 , new_weight_layer_output .size (- 1 ))
260289 len_temp = original_layer_output .shape [0 ] / k - 1
290+ # if len_temp == 0:
291+ # len_temp = 1
292+ # print(len_temp)
293+ # print(act_mask)
261294 for i ,act_mk in enumerate (act_mask ):
262295 if act_mk is not None :
263296 in_scope_dist = euc (original_layer_output [int (i * len_temp ):int ((i + 1 )* len_temp ), ...], new_weight_layer_output [int (i * len_temp ):int ((i + 1 )* len_temp ), ...], config ,
@@ -270,7 +303,8 @@ def _cal_activation_loss(self, original_layer_output, new_weight_layer_output, c
270303 out_scope_dist = euc (original_layer_output [int (i - k ):, ...], new_weight_layer_output [int (i - k ):, ...], config )
271304 else :
272305 out_scope_dist = euc (original_layer_output [int (i - k ):int (i + 1 - k ), ...], new_weight_layer_output [int (i - k ):int (i + 1 - k ), ...], config )
273-
306+ # print("in_scope_dist: ", in_scope_dist)
307+ # print("out_scope_dist: ", out_scope_dist)
274308 loss = out_scope_dist .view (- 1 ,1 ) - in_scope_dist + config .gamma
275309 loss2 = out_scope_dist - config .alpha
276310 loss3 = config .beta - in_scope_dist
@@ -539,8 +573,12 @@ def edit(self, config, multimodal_inputs, text_tokens, ans_token_len, act_mask=N
539573 ft_loss = self ._cal_ft_loss (multimodal_inputs , text_tokens , last_prompt_token_loc , ans_token_len )
540574
541575 act_loss = super ()._cal_activation_loss (super ().get_adapter_layer ().original_layer_output , super ().get_adapter_layer ().new_weight_layer_output ,
542- config = config , act_mask = act_mask , deact_mask = deact_mask )
576+ config = config , act_mask = act_mask , deact_mask = deact_mask )
543577 loss = ft_loss + act_loss .to (ft_loss .device )
578+ # if self.config.model_name == "blip2":
579+ # print(self.model.generate(multimodal_inputs[0]))
580+ # elif self.config.model_name == "minigpt4":
581+ # print(self.model.predict_answers(multimodal_inputs))
544582
545583 if loss_meter .stop ():
546584 super ().get_adapter_layer ().save_editing_activation () # add last gradient
@@ -632,10 +670,17 @@ def _cal_ft_loss(self, multimodal_inputs, text_tokens, last_prompt_token_loc, an
632670 if k != 1 :
633671 raise AssertionError ("Not support Batch Edit" )
634672
635- bs = text_tokens ["input_ids" ].shape [0 ] - k
636- logits = self .model (** multimodal_inputs ).logits
637- shift_logits = logits [:- k , :- 1 , :].contiguous ()
638- shift_labels = multimodal_inputs ['input_ids' ][:- k , 1 :].contiguous ()
673+ if self .config .model_name == "blip2" or self .config .model_name == "minigpt4" :
674+ logits = self .model (multimodal_inputs ).logits
675+ labels = text_tokens ["labels" ]
676+ shift_labels = labels [:, 1 :].contiguous ()
677+ shift_logits = logits [:- k , :- 1 , :].contiguous ()
678+ bs = text_tokens ["labels" ].shape [0 ]
679+ else :
680+ logits = self .model (** multimodal_inputs ).logits
681+ shift_labels = multimodal_inputs ['input_ids' ][:- k , 1 :].contiguous ()
682+ shift_logits = logits [:- k , :- 1 , :].contiguous ()
683+ bs = text_tokens ["input_ids" ].shape [0 ] - k
639684 # only cal loss of target text tokens
640685 loss_fct = CrossEntropyLoss (reduction = 'none' )
641686 a = shift_logits .view (- 1 , shift_logits .size (- 1 ))
@@ -645,5 +690,4 @@ def _cal_ft_loss(self, multimodal_inputs, text_tokens, last_prompt_token_loc, an
645690 loss = loss .view (bs , - 1 )
646691 label_mask = torch .ones_like (loss , dtype = torch .bool )
647692 ft_loss = ((loss * label_mask ).sum (1 ) / label_mask .sum (1 )).mean ()
648- return ft_loss
649-
693+ return ft_loss
0 commit comments