Skip to content

Commit a642707

Browse files
authored
Merge pull request #639 from falling-leaf/PR-branch
PR for updating MEND, WISE method on blip2, minigpt4, qwen2-vl, llava-ov
2 parents 25a9ec4 + ce905bc commit a642707

23 files changed

+1224
-230
lines changed

easyeditor/dataset/coco_caption.py

Lines changed: 182 additions & 73 deletions
Large diffs are not rendered by default.

easyeditor/dataset/processor/blip_processors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def __init__(self, image_size=384, mean=None, std=None):
144144
)
145145

146146
def __call__(self, item, file_type=None):
147-
item = Image.open(item)
147+
item = Image.open(item).convert("RGB")
148148
return self.transform(item)
149149

150150
@classmethod

easyeditor/dataset/vqa.py

Lines changed: 189 additions & 74 deletions
Large diffs are not rendered by default.

easyeditor/editors/multimodal_editor.py

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -136,19 +136,23 @@ def __init__(self,
136136
self.rephrase_root = hparams.rephrase_image
137137

138138
elif "llava-onevision" in hparams.model_name.lower():
139+
if not hasattr(hparams, 'dtype'):
140+
hparams.dtype = torch.float32
139141
self.model = LlavaOnevisionForConditionalGeneration.from_pretrained(
140142
hparams.model_name,
141-
torch_dtype=torch.float32,
143+
torch_dtype=hparams.dtype,
142144
# attn_implementation="flash_attention_2"
143145
)
144146
self.vis_tok = LLaVAOneVisionProcessor()
145147
self.tok = AutoProcessor.from_pretrained(hparams.model_name)
146148
self.model_name = "llava-onevision"
147149

148150
elif "qwen2-vl" in hparams.model_name.lower():
151+
if not hasattr(hparams, 'dtype'):
152+
hparams.dtype = torch.float32
149153
self.model = Qwen2VLForConditionalGeneration.from_pretrained(
150154
hparams.model_name,
151-
torch_dtype=torch.float32,
155+
torch_dtype=hparams.dtype,
152156
# attn_implementation="flash_attention_2"
153157
)
154158
self.vis_tok = Qwen2VLProcessor()
@@ -443,10 +447,16 @@ def edit_dataset(self,
443447
train_ds=kwargs['train_ds']
444448
)
445449
else:
450+
if self.model_name in ['minigpt4', 'blip2']:
451+
pre_res = compute_multimodal_edit_results(self.model, self.model_name, self.hparams, self.tok,
452+
request, self.hparams.device)
453+
elif self.model_name in ['llava-onevision', 'qwen2-vl']:
454+
pre_res = compute_multimodal_hf_edit_results(self.model, self.model_name, self.hparams, self.tok,
455+
request, self.hparams.device)
446456
edited_model, weights_copy = self.apply_algo(
447457
self.model,
448458
self.tok,
449-
request,
459+
[request],
450460
self.hparams,
451461
copy=False,
452462
return_orig_weights=True,
@@ -466,15 +476,32 @@ def edit_dataset(self,
466476
request, self.hparams.device, pre_edit=True)
467477
}
468478
else:
469-
metrics = {
470-
'case_id': i,
471-
# "requested_rewrite": request,
472-
"time": exec_time,
473-
"post": compute_multimodal_edit_results(edited_model, self.model_name, self.hparams, self.tok,
474-
request, self.hparams.device),
475-
"pre": compute_multimodal_edit_results(self.model, self.model_name, self.hparams, self.tok,
476-
request, self.hparams.device)
477-
}
479+
if self.model_name in ['minigpt4', 'blip2']:
480+
metrics = {
481+
'case_id': i,
482+
"time": exec_time,
483+
"post": compute_multimodal_edit_results(edited_model, self.model_name, self.hparams, self.tok,
484+
request, self.hparams.device),
485+
"pre": pre_res
486+
}
487+
elif self.model_name in ['llava-onevision', 'qwen2-vl']:
488+
metrics = {
489+
'case_id': i,
490+
# "requested_rewrite": request,
491+
"time": exec_time,
492+
"post": compute_multimodal_hf_edit_results(edited_model, self.model_name, self.hparams, self.tok,
493+
request, self.hparams.device),
494+
"pre": pre_res
495+
}
496+
# metrics = {
497+
# 'case_id': i,
498+
# # "requested_rewrite": request,
499+
# "time": exec_time,
500+
# "post": compute_multimodal_edit_results(edited_model, self.model_name, self.hparams, self.tok,
501+
# request, self.hparams.device),
502+
# "pre": compute_multimodal_edit_results(self.model, self.model_name, self.hparams, self.tok,
503+
# request, self.hparams.device)
504+
# }
478505
if 'locality_output' in metrics['post'].keys():
479506
assert len(metrics['post']['locality_output']) == \
480507
len(metrics['pre']['locality_output'])
@@ -648,7 +675,6 @@ def _prepare_requests(self,
648675
'multimodal_locality_ground_truth': multimodal_locality_ground_truth[i],
649676
}
650677
)
651-
652678
if 'loc_prompts' in kwargs:
653679
if isinstance(kwargs['loc_prompts'], str):
654680
kwargs['loc_prompts'] = [kwargs['loc_prompts'],]

easyeditor/evaluate/multimodal_evaluate.py

Lines changed: 57 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -217,15 +217,18 @@ def prepare_multimodal_hf_edit(hparams,
217217
add_generation_prompt=True,
218218
tokenize=False) + l
219219
for p, l in zip(prompts, targets)]
220+
if "qwen2-vl" in hparams.model_name.lower() and "|vision_start|" not in text_input[0]:
221+
image_token = "<|vision_start|><|image_pad|><|vision_end|>"
222+
text_input = [image_token + text_input[0]]
220223
else:
221224
raise AssertionError("Not support file type: {}".format(file_type))
222225

223226
if file_type in ["image", "single-image", "multi-image"]:
224-
multimodal_inputs = processor(images=image, text=text_input, return_tensors="pt").to(hparams.device, dtype=torch.float32)
227+
multimodal_inputs = processor(images=image, text=text_input, return_tensors="pt").to(hparams.device, dtype=hparams.dtype)
225228
elif file_type == "video":
226-
multimodal_inputs = processor(videos=image, text=text_input, return_tensors="pt").to(hparams.device, dtype=torch.float32)
229+
multimodal_inputs = processor(videos=image, text=text_input, return_tensors="pt").to(hparams.device, dtype=hparams.dtype)
227230
elif file_type == "text":
228-
multimodal_inputs = processor(text=text_input, return_tensors="pt").to(hparams.device, dtype=torch.float32)
231+
multimodal_inputs = processor(text=text_input, return_tensors="pt").to(hparams.device, dtype=hparams.dtype)
229232

230233
targets = processor.tokenizer(targets, add_special_tokens=False,
231234
return_tensors="pt", padding=True, max_length=multimodal_inputs["input_ids"].size(1))["input_ids"]
@@ -271,6 +274,43 @@ def compute_multimodal_hf_edit_quality(model, batch, tok,exach_match=False):
271274
pred_ids = pred_ids.masked_select(pred_ids != 0).view(1, -1)
272275
return acc, pred_ids.numpy()
273276

277+
def compute_multimodal_hf_edit_quality_demo(model, batch, tok, exach_match=False):
278+
with torch.no_grad():
279+
outputs = model(**batch["multimodal_inputs"])
280+
if isinstance(outputs, torch.Tensor):
281+
logits = outputs.detach().cpu()
282+
targ = batch["labels"].cpu()
283+
else:
284+
logits = outputs.logits.detach().cpu()
285+
targ = batch["labels"].cpu()
286+
287+
# 创建logits副本 - 这是demo版本的关键区别
288+
logits_ = logits.clone()
289+
290+
if logits.dim() == 3:
291+
logits = logits[:, :-1, :]
292+
targ = targ[:, 1:]
293+
294+
mask = targ != -100
295+
targ[~mask] = 0
296+
if exach_match:
297+
pred_ids = logits.argmax(-1).masked_fill(~mask, 0)
298+
correct = pred_ids == targ
299+
if logits.dim() == 3:
300+
correct = (pred_ids == targ).all(-1) # We aim for an exact match across the entire sequence
301+
acc = correct.float().mean()
302+
else:
303+
pred_ids = logits.argmax(-1).masked_fill(~mask, 0).detach().cpu()
304+
correct = pred_ids == targ
305+
correct = correct & mask
306+
num_non_padding = mask.sum().float().item()
307+
acc = correct.sum() / num_non_padding
308+
309+
pred_ids = pred_ids.masked_select(pred_ids != 0).view(1, -1)
310+
311+
# demo版本返回完整的logits用于进一步分析
312+
return acc, pred_ids.numpy(), logits_
313+
274314

275315
def compute_multimodal_edit_quality(model, batch, exact_match=False):
276316
with torch.no_grad():
@@ -360,7 +400,16 @@ def compute_multimodal_edit_results(
360400

361401
target = record["target"]
362402
rewrite_prompts = record["prompt"]
363-
image = record["image"] if record["image"].is_cuda else record["image"].to(hparams.device)
403+
# image = record["image"] if record["image"].is_cuda else record["image"].to(hparams.device)
404+
405+
# 由于edit_dataset无prepare,因此request
406+
if hasattr(record["image"], 'is_cuda'): # 如果是PyTorch张量
407+
image = record["image"] if record["image"].is_cuda else record["image"].to(hparams.device)
408+
else: # 如果是PIL图像或其他类型
409+
# 需要先将PIL图像转换为张量
410+
from torchvision import transforms
411+
transform = transforms.ToTensor()
412+
image = transform(record["image"]).to(hparams.device)
364413

365414
edit_inner = prepare_multimodal_edit(hparams, tok, target, rewrite_prompts, image)
366415
ret['rewrite_acc'], _ = compute_multimodal_edit_quality(model, edit_inner)
@@ -439,14 +488,16 @@ def compute_multimodal_hf_edit_results(
439488
locality_prompt = record["locality_prompt"]
440489
locality_ground_truth = record["locality_ground_truth"]
441490
locality = prepare_multimodal_hf_edit(hparams, tok, locality_ground_truth, locality_prompt, None, file_type="text")
442-
_, ret['locality_output'] = compute_multimodal_hf_edit_quality(model, locality, tok)
491+
# _, ret['locality_output'] = compute_multimodal_hf_edit_quality(model, locality, tok)
492+
_, _, ret['locality_output'] = compute_multimodal_hf_edit_quality_demo(model, locality, tok)
443493

444494
if 'multimodal_locality_prompt' in record.keys():
445495
m_loc_prompt = record["multimodal_locality_prompt"]
446496
m_loc_ground_truth = record["multimodal_locality_ground_truth"]
447497
m_loc_image = record["multimodal_locality_image"]
448498
m_locality = prepare_multimodal_hf_edit(hparams, tok, m_loc_ground_truth, m_loc_prompt, m_loc_image, file_type="image")
449-
_, ret['multimodal_locality_output'] = compute_multimodal_hf_edit_quality(model, m_locality, tok)
499+
# _, ret['multimodal_locality_output'] = compute_multimodal_hf_edit_quality(model, m_locality, tok)
500+
_, _, ret['multimodal_locality_output'] = compute_multimodal_hf_edit_quality_demo(model, m_locality, tok)
450501

451502
return ret
452503

easyeditor/models/wise/WISE.py

Lines changed: 62 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ def euc(query, key, config, act_mask=None, infer=False):
3232

3333
act_fn = ACT2FN[config.hidden_act]
3434
l2_norm = torch.norm(act_fn(key) - act_fn(query), dim=-1)
35+
if l2_norm.dim() == 1:
36+
l2_norm = l2_norm.unsqueeze(0)
3537
if infer and l2_norm.size(1) > 100:
3638
topk = torch.topk(l2_norm, k=1, largest=True)
3739
return topk.values.mean()
@@ -74,7 +76,20 @@ def __init__(self, config, model, device):
7476
self.layer_name = self.layer.rsplit(".", 1)[-1]
7577
adapter_layer = getattr(self.edit_module, self.layer_name)
7678

77-
if type(adapter_layer) is not WISEAdapter:
79+
# if the condition below is True, then it is single-edit
80+
if not config.sequential_edit:
81+
# if type(adapter_layer) is not WISEAdapter:
82+
# 如果 adapter_layer 已经是 WISEAdapter,提取其原始层
83+
if type(adapter_layer) is WISEAdapter:
84+
# 使用 original_layer 作为基础层(这是保存的原始层副本)
85+
base_layer = adapter_layer.original_layer
86+
else:
87+
base_layer = adapter_layer
88+
89+
setattr(self.edit_module, self.layer_name, WISEAdapter(config, base_layer, transpose=transpose))
90+
self.original_layer = copy.deepcopy(base_layer)
91+
print(f"New weights successfully inserted into {layer}")
92+
elif type(adapter_layer) is not WISEAdapter:
7893
setattr(self.edit_module, self.layer_name, WISEAdapter(config, adapter_layer, transpose=transpose))
7994
self.original_layer = copy.deepcopy(adapter_layer)
8095
print(f"New weights successfully inserted into {layer}")
@@ -84,16 +99,27 @@ def __init__(self, config, model, device):
8499
gc.collect()
85100

86101
# Forward
87-
def __call__(self, **kwargs):
102+
def __call__(self, *args, **kwargs):
88103
if not self.config.retrieve:
89-
if hasattr(self.get_adapter_layer(), 'editing') and not self.get_adapter_layer().editing:
90-
# final merge
91-
if not self.get_adapter_layer().original_layer.weight.equal(self.get_adapter_layer().new_weight) and self.get_adapter_layer().editing_total_cnt >= self.config.save_freq:
92-
self.get_adapter_layer().memory_weight.append(self.get_adapter_layer().new_weight)
93-
if len(self.get_adapter_layer().memory_weight) > 0 and self.get_adapter_layer().editing_total_cnt >= self.config.save_freq:
94-
print('length of memory is ', len(self.get_adapter_layer().memory_weight), '!!!!!!')
95-
self.get_adapter_layer().merge_weight()
96-
return self.model(**kwargs)
104+
adapter = self.get_adapter_layer()
105+
if hasattr(adapter, 'editing') and not adapter.editing:
106+
if (not adapter.original_layer.weight.equal(adapter.new_weight)
107+
and adapter.editing_total_cnt >= self.config.save_freq):
108+
adapter.memory_weight.append(adapter.new_weight)
109+
110+
if len(adapter.memory_weight) > 0 and adapter.editing_total_cnt >= self.config.save_freq:
111+
print('length of memory is ', len(adapter.memory_weight), '!!!!!!')
112+
adapter.merge_weight()
113+
# 1. 如果用户传入 model(batch)
114+
if len(args) == 1 and isinstance(args[0], dict):
115+
return self.model(args[0])
116+
# 2. 如果用户传入 model(batch=batch)
117+
elif "batch" in kwargs and isinstance(kwargs["batch"], dict):
118+
batch = kwargs.pop("batch")
119+
return self.model(**batch, **kwargs)
120+
# 3. 普通 HuggingFace 风格,如 model(input_ids=..., pixel_values=...)
121+
else:
122+
return self.model(**kwargs)
97123

98124
def reset_layer(self):
99125
layer = getattr(self.edit_module, self.layer_name)
@@ -257,7 +283,14 @@ def _cal_activation_loss(self, original_layer_output, new_weight_layer_output, c
257283
else:
258284
k = 1
259285
total_loss = []
286+
if self.config.model_name == "blip2":
287+
original_layer_output = original_layer_output.reshape(2, -1, original_layer_output.size(-1))
288+
new_weight_layer_output = new_weight_layer_output.reshape(2, -1, new_weight_layer_output.size(-1))
260289
len_temp = original_layer_output.shape[0] / k - 1
290+
# if len_temp == 0:
291+
# len_temp = 1
292+
# print(len_temp)
293+
# print(act_mask)
261294
for i,act_mk in enumerate(act_mask):
262295
if act_mk is not None:
263296
in_scope_dist = euc(original_layer_output[int(i*len_temp):int((i+1)*len_temp), ...], new_weight_layer_output[int(i*len_temp):int((i+1)*len_temp), ...], config,
@@ -270,7 +303,8 @@ def _cal_activation_loss(self, original_layer_output, new_weight_layer_output, c
270303
out_scope_dist = euc(original_layer_output[int(i-k):, ...], new_weight_layer_output[int(i-k):, ...], config)
271304
else:
272305
out_scope_dist = euc(original_layer_output[int(i-k):int(i+1-k), ...], new_weight_layer_output[int(i-k):int(i+1-k), ...], config)
273-
306+
# print("in_scope_dist: ", in_scope_dist)
307+
# print("out_scope_dist: ", out_scope_dist)
274308
loss = out_scope_dist.view(-1,1) - in_scope_dist + config.gamma
275309
loss2 = out_scope_dist - config.alpha
276310
loss3 = config.beta - in_scope_dist
@@ -539,8 +573,12 @@ def edit(self, config, multimodal_inputs, text_tokens, ans_token_len, act_mask=N
539573
ft_loss = self._cal_ft_loss(multimodal_inputs, text_tokens, last_prompt_token_loc, ans_token_len)
540574

541575
act_loss = super()._cal_activation_loss(super().get_adapter_layer().original_layer_output, super().get_adapter_layer().new_weight_layer_output,
542-
config=config, act_mask=act_mask, deact_mask=deact_mask)
576+
config=config, act_mask=act_mask, deact_mask=deact_mask)
543577
loss = ft_loss + act_loss.to(ft_loss.device)
578+
# if self.config.model_name == "blip2":
579+
# print(self.model.generate(multimodal_inputs[0]))
580+
# elif self.config.model_name == "minigpt4":
581+
# print(self.model.predict_answers(multimodal_inputs))
544582

545583
if loss_meter.stop():
546584
super().get_adapter_layer().save_editing_activation() # add last gradient
@@ -632,10 +670,17 @@ def _cal_ft_loss(self, multimodal_inputs, text_tokens, last_prompt_token_loc, an
632670
if k != 1:
633671
raise AssertionError("Not support Batch Edit")
634672

635-
bs = text_tokens["input_ids"].shape[0] - k
636-
logits = self.model(**multimodal_inputs).logits
637-
shift_logits = logits[:-k, :-1, :].contiguous()
638-
shift_labels = multimodal_inputs['input_ids'][:-k, 1:].contiguous()
673+
if self.config.model_name == "blip2" or self.config.model_name == "minigpt4":
674+
logits = self.model(multimodal_inputs).logits
675+
labels = text_tokens["labels"]
676+
shift_labels = labels[:, 1:].contiguous()
677+
shift_logits = logits[:-k, :-1, :].contiguous()
678+
bs = text_tokens["labels"].shape[0]
679+
else:
680+
logits = self.model(**multimodal_inputs).logits
681+
shift_labels = multimodal_inputs['input_ids'][:-k, 1:].contiguous()
682+
shift_logits = logits[:-k, :-1, :].contiguous()
683+
bs = text_tokens["input_ids"].shape[0] - k
639684
# only cal loss of target text tokens
640685
loss_fct = CrossEntropyLoss(reduction='none')
641686
a = shift_logits.view(-1, shift_logits.size(-1))
@@ -645,5 +690,4 @@ def _cal_ft_loss(self, multimodal_inputs, text_tokens, last_prompt_token_loc, an
645690
loss = loss.view(bs, -1)
646691
label_mask = torch.ones_like(loss, dtype=torch.bool)
647692
ft_loss = ((loss * label_mask).sum(1) / label_mask.sum(1)).mean()
648-
return ft_loss
649-
693+
return ft_loss

0 commit comments

Comments
 (0)