initiate pylint effort
frankaging committed Jan 11, 2024
1 parent 2bb665a commit fc5ff91
Showing 47 changed files with 5,845 additions and 4,047 deletions.
204 changes: 123 additions & 81 deletions data_generators/causal_model.py

Large diffs are not rendered by default.

60 changes: 39 additions & 21 deletions models/basic_utils.py
@@ -1,4 +1,11 @@
import os, copy, torch, random, importlib
"""
Basic Utils
"""
import os
import random
import importlib
import torch

from torch import nn
import numpy as np

@@ -13,7 +20,7 @@ def get_type_from_string(type_str):
type_str = type_str.replace("<class '", "").replace("'>", "")

# Split the string into module and class name
module_name, class_name = type_str.rsplit('.', 1)
module_name, class_name = type_str.rsplit(".", 1)

# Import the module
module = importlib.import_module(module_name)
@@ -32,7 +39,7 @@ def create_directory(path):
else:
print(f"Directory '{path}' already exists.")


def embed_to_distrib(model, embed, log=False, logits=False):
"""Convert an embedding to a distribution over the vocabulary"""
if "gpt2" in model.config.architectures[0].lower():
@@ -44,7 +51,7 @@ def embed_to_distrib(model, embed, log=False, logits=False):
elif "llama" in model.config.architectures[0].lower():
assert False, "Support for LLaMA is not here yet"


def set_seed(seed: int):
"""Set seed. Deprecate soon since it is in the huggingface library"""
random.seed(seed)
@@ -55,16 +62,28 @@ def set_seed(seed: int):

def sigmoid_boundary(_input, boundary_x, boundary_y, temperature):
"""Generate sigmoid mask"""
return torch.sigmoid((_input - boundary_x) / temperature) * \
torch.sigmoid((boundary_y - _input) / temperature)
return torch.sigmoid((_input - boundary_x) / temperature) * torch.sigmoid(
(boundary_y - _input) / temperature
)


def harmonic_sigmoid_boundary(_input, boundary_x, boundary_y, temperature):
"""Generate harmonic sigmoid mask"""
return (_input<=boundary_x)*torch.sigmoid((_input - boundary_x) / temperature) + \
(_input>=boundary_y)*torch.sigmoid((boundary_y - _input) / temperature) + \
((_input>boundary_x)&(_input<boundary_y))*torch.sigmoid(
(0.5 * (torch.abs(_input - boundary_x)**(-1) + torch.abs(_input - boundary_y)**(-1)))**(-1) / temperature
return (
(_input <= boundary_x) * torch.sigmoid((_input - boundary_x) / temperature)
+ (_input >= boundary_y) * torch.sigmoid((boundary_y - _input) / temperature)
+ ((_input > boundary_x) & (_input < boundary_y))
* torch.sigmoid(
(
0.5
* (
torch.abs(_input - boundary_x) ** (-1)
+ torch.abs(_input - boundary_y) ** (-1)
)
)
** (-1)
/ temperature
)
)


@@ -75,19 +94,19 @@ def count_parameters(model):

def random_permutation_matrix(n):
"""Generate a random permutation matrix"""
P = torch.eye(n)
_p = torch.eye(n)
perm = torch.randperm(n)
P = P[perm]
return P
_p = _p[perm]

return _p


def closeness_to_permutation_loss(R):
def closeness_to_permutation_loss(rotation):
"""Measure how close a rotation m is close to a permutation m"""
row_sum_diff = torch.abs(R.sum(dim=1) - 1.0).mean()
col_sum_diff = torch.abs(R.sum(dim=0) - 1.0).mean()
entry_diff = (R * (1 - R)).mean()
loss = .5 * (row_sum_diff + col_sum_diff) + entry_diff
row_sum_diff = torch.abs(rotation.sum(dim=1) - 1.0).mean()
col_sum_diff = torch.abs(rotation.sum(dim=0) - 1.0).mean()
entry_diff = (rotation * (1 - rotation)).mean()
loss = 0.5 * (row_sum_diff + col_sum_diff) + entry_diff
return loss


@@ -99,7 +118,6 @@ def format_token(tokenizer, tok):
def top_vals(tokenizer, res, n=10):
"""Pretty print the top n values of a distribution over the vocabulary"""
top_values, top_indices = torch.topk(res, n)
for i in range(len(top_values)):
for i, _ in enumerate(top_values):
tok = format_token(tokenizer, top_indices[i].item())
print(f"{tok:<20} {top_values[i].item()}")

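The hunks above are naming and formatting cleanups; the helper signatures are unchanged. Below is a minimal usage sketch of three of the refactored helpers. It is not part of this commit and assumes models/basic_utils.py is importable as models.basic_utils from the repository root:

import torch
from models.basic_utils import (
    sigmoid_boundary,
    random_permutation_matrix,
    closeness_to_permutation_loss,
)

positions = torch.arange(10).float()        # token positions 0..9
mask = sigmoid_boundary(positions, 2.0, 6.0, temperature=0.1)
# mask is close to 1 strictly inside (2, 6), about 0.5 at the boundaries, close to 0 outside

perm = random_permutation_matrix(5)         # shuffled identity: one 1 per row and column
loss = closeness_to_permutation_loss(perm)  # 0.0 for an exact permutation matrix
print(mask)
print(loss.item())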
132 changes: 71 additions & 61 deletions models/blip/modelings_blip.py
@@ -4,74 +4,84 @@
from transformers.utils import ModelOutput
from typing import Optional, Union, Tuple, Dict


class BlipWrapper(nn.Module):
def __init__(self, model: BlipForQuestionAnswering):
super(BlipWrapper, self).__init__()
self.model_vis = model.vision_model
self.model_text_enc = model.text_encoder
self.model_text_dec = model.text_decoder
self.decoder_pad_token_id = model.decoder_pad_token_id
self.decoder_start_token_id = model.decoder_start_token_id
self.config = model.config
self.eos_token_id = model.config.text_config.sep_token_id,
self.pad_token_id = model.config.text_config.pad_token_id
self.output_attentions = model.config.output_attentions
self.use_return_dict = model.config.use_return_dict
self.output_hidden_states = model.config.output_hidden_states

def forward(
self,
input_ids: torch.LongTensor,
pixel_values: torch.FloatTensor,
attention_mask: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, Dict]:
super(BlipWrapper, self).__init__()
self.model_vis = model.vision_model
self.model_text_enc = model.text_encoder
self.model_text_dec = model.text_decoder
self.decoder_pad_token_id = model.decoder_pad_token_id
self.decoder_start_token_id = model.decoder_start_token_id
self.config = model.config
self.eos_token_id = (model.config.text_config.sep_token_id,)
self.pad_token_id = model.config.text_config.pad_token_id
self.output_attentions = model.config.output_attentions
self.use_return_dict = model.config.use_return_dict
self.output_hidden_states = model.config.output_hidden_states

return_dict = return_dict if return_dict is not None else self.use_return_dict
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.output_hidden_states
)
def forward(
self,
input_ids: torch.LongTensor,
pixel_values: torch.FloatTensor,
attention_mask: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, Dict]:
return_dict = return_dict if return_dict is not None else self.use_return_dict
output_attentions = (
output_attentions
if output_attentions is not None
else self.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.output_hidden_states
)

vision_outputs = self.model_vis(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
vision_outputs = self.model_vis(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

image_embeds = vision_outputs[0].to(self.model_text_enc.device)
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long)
image_embeds = vision_outputs[0].to(self.model_text_enc.device)
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long)

input_ids = input_ids.to(self.model_text_enc.device)
question_embeds = self.model_text_enc(
input_ids=input_ids,
attention_mask=attention_mask,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_attention_mask,
output_hidden_states=True,
)
input_ids = input_ids.to(self.model_text_enc.device)
question_embeds = self.model_text_enc(
input_ids=input_ids,
attention_mask=attention_mask,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_attention_mask,
output_hidden_states=True,
)

question_embeds_w = question_embeds[0] if not return_dict else question_embeds.last_hidden_state
question_embeds_w = (
question_embeds[0] if not return_dict else question_embeds.last_hidden_state
)

bos_ids = torch.full(
(question_embeds_w.size(0), 1), fill_value=self.decoder_start_token_id, device=self.model_text_enc.device
)
bos_ids = torch.full(
(question_embeds_w.size(0), 1),
fill_value=self.decoder_start_token_id,
device=self.model_text_enc.device,
)

answer_output = self.model_text_dec(
input_ids=bos_ids,
encoder_hidden_states=question_embeds_w,
encoder_attention_mask=attention_mask,
output_hidden_states=True,
reduction="mean"
)
answer_output = self.model_text_dec(
input_ids=bos_ids,
encoder_hidden_states=question_embeds_w,
encoder_attention_mask=attention_mask,
output_hidden_states=True,
reduction="mean",
)

return {
'decoder_logits': answer_output.logits,
'image_embeds': image_embeds,
'encoder_last_hidden_state': question_embeds.last_hidden_state,
'encoder_hidden_states': question_embeds.hidden_states,
'decoder_hidden_states': answer_output.hidden_states,
}
return {
"decoder_logits": answer_output.logits,
"image_embeds": image_embeds,
"encoder_last_hidden_state": question_embeds.last_hidden_state,
"encoder_hidden_states": question_embeds.hidden_states,
"decoder_hidden_states": answer_output.hidden_states,
}
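BlipWrapper runs the BLIP vision encoder, text encoder, and text decoder explicitly and returns their logits and hidden states in a plain dict. A hedged usage sketch follows; it is not part of this commit, the import path models.blip.modelings_blip and the Salesforce/blip-vqa-base checkpoint name are assumptions based on the file paths and defaults visible in this diff, and the blank white image is only a placeholder:

import torch
from PIL import Image
from transformers import BlipProcessor, BlipForQuestionAnswering
from models.blip.modelings_blip import BlipWrapper

processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
blip = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
wrapper = BlipWrapper(blip).eval()

image = Image.new("RGB", (384, 384), color="white")   # placeholder image for illustration
inputs = processor(images=image, text="what color is the image?", return_tensors="pt")

with torch.no_grad():
    out = wrapper(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        attention_mask=inputs["attention_mask"],
    )

# Logits for the single decoder start token, plus encoder/decoder hidden states
print(out["decoder_logits"].shape)
print(len(out["encoder_hidden_states"]), len(out["decoder_hidden_states"]))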
70 changes: 42 additions & 28 deletions models/blip/modelings_intervenable_blip.py
@@ -21,16 +21,23 @@
# 'vis.attention_value_output': ("vision_model.encoder.layers[%s].self_attn.projection", CONST_INPUT_HOOK),
# 'vis.attention_output': ("vision_model.encoder.layers[%s].self_attn", CONST_OUTPUT_HOOK),
# 'vis.attention_input': ("vision_model.encoder.layers[%s].self_attn", CONST_INPUT_HOOK),

'block_input': ("text_encoder.encoder.layer[%s]", CONST_INPUT_HOOK),
'block_output': ("text_encoder.encoder.layer[%s]", CONST_INPUT_HOOK),
'mlp_activation': ("text_encoder.encoder.layer[%s].intermediate.dense", CONST_OUTPUT_HOOK),
'mlp_output': ("text_encoder.encoder.layer[%s].output", CONST_OUTPUT_HOOK),
'mlp_input': ("text_encoder.encoder.layer[%s].intermediate", CONST_INPUT_HOOK),
'attention_value_output': ("text_encoder.encoder.layer[%s].attention.output.dense", CONST_INPUT_HOOK),
'attention_output': ("text_encoder.encoder.layer[%s].attention.output", CONST_OUTPUT_HOOK),
'attention_input': ("text_encoder.encoder.layer[%s].attention", CONST_INPUT_HOOK),

"block_input": ("text_encoder.encoder.layer[%s]", CONST_INPUT_HOOK),
"block_output": ("text_encoder.encoder.layer[%s]", CONST_INPUT_HOOK),
"mlp_activation": (
"text_encoder.encoder.layer[%s].intermediate.dense",
CONST_OUTPUT_HOOK,
),
"mlp_output": ("text_encoder.encoder.layer[%s].output", CONST_OUTPUT_HOOK),
"mlp_input": ("text_encoder.encoder.layer[%s].intermediate", CONST_INPUT_HOOK),
"attention_value_output": (
"text_encoder.encoder.layer[%s].attention.output.dense",
CONST_INPUT_HOOK,
),
"attention_output": (
"text_encoder.encoder.layer[%s].attention.output",
CONST_OUTPUT_HOOK,
),
"attention_input": ("text_encoder.encoder.layer[%s].attention", CONST_INPUT_HOOK),
# 'block_input': ("text_decoder.bert.encoder.layer[%s]", CONST_INPUT_HOOK),
# 'block_output': ("text_decoder.bert.encoder.layer[%s]", CONST_INPUT_HOOK),
# 'mlp_activation': ("text_decoder.bert.encoder.layer[%s].intermediate.dense", CONST_OUTPUT_HOOK),
@@ -46,15 +53,14 @@


blip_type_to_dimension_mapping = {
# 'vis.block_input': ("image_text_hidden_size", ),
# 'vis.block_output': ("image_text_hidden_size", ),
# 'vis.block_input': ("image_text_hidden_size", ),
# 'vis.block_output': ("image_text_hidden_size", ),
# 'vis.mlp_activation': ("projection_dim", ),
# 'vis.mlp_output': ("image_text_hidden_size", ),
# 'vis.mlp_input': ("image_text_hidden_size", ),
# 'vis.attention_value_output': ("image_text_hidden_size/text_config.num_attention_heads", ),
# 'vis.attention_output': ("image_text_hidden_size", ),
# 'vis.attention_input': ("image_text_hidden_size", ),

# 'lang.block_input': ("image_text_hidden_size", ),
# 'lang.block_output': ("image_text_hidden_size", ),
# 'lang.mlp_activation': ("projection_dim", ),
@@ -63,25 +69,31 @@
# 'lang.attention_value_output': ("image_text_hidden_size/text_config.num_attention_heads", ),
# 'lang.attention_output': ("image_text_hidden_size", ),
# 'lang.attention_input': ("image_text_hidden_size", ),

'block_input': ("image_text_hidden_size", ),
'block_output': ("image_text_hidden_size", ),
'mlp_activation': ("projection_dim", ),
'mlp_output': ("image_text_hidden_size", ),
'mlp_input': ("image_text_hidden_size", ),
'attention_value_output': ("image_text_hidden_size/text_config.num_attention_heads", ),
'attention_output': ("image_text_hidden_size", ),
'attention_input': ("image_text_hidden_size", ),
'cross_attention_value_output': ("image_text_hidden_size/text_config.num_attention_heads", ),
'cross_attention_output': ("image_text_hidden_size", ),
'cross_attention_input': ("image_text_hidden_size", ),
"block_input": ("image_text_hidden_size",),
"block_output": ("image_text_hidden_size",),
"mlp_activation": ("projection_dim",),
"mlp_output": ("image_text_hidden_size",),
"mlp_input": ("image_text_hidden_size",),
"attention_value_output": (
"image_text_hidden_size/text_config.num_attention_heads",
),
"attention_output": ("image_text_hidden_size",),
"attention_input": ("image_text_hidden_size",),
"cross_attention_value_output": (
"image_text_hidden_size/text_config.num_attention_heads",
),
"cross_attention_output": ("image_text_hidden_size",),
"cross_attention_input": ("image_text_hidden_size",),
}


"""blip model with wrapper"""
blip_wrapper_type_to_module_mapping = {}
for k, v in blip_type_to_module_mapping.items():
blip_wrapper_type_to_module_mapping[k] = (v[0].replace('text_encoder', 'model_text_enc'), v[1])
blip_wrapper_type_to_module_mapping[k] = (
v[0].replace("text_encoder", "model_text_enc"),
v[1],
)


blip_wrapper_type_to_dimension_mapping = blip_type_to_dimension_mapping
@@ -93,6 +105,8 @@ def create_blip(name="Salesforce/blip-vqa-base", cache_dir="../../.huggingface_c

config = BlipConfig.from_pretrained(name)
processor = BlipProcessor.from_pretrained(name)
blip = BlipForQuestionAnswering.from_pretrained(name, config=config, cache_dir=cache_dir)
blip = BlipForQuestionAnswering.from_pretrained(
name, config=config, cache_dir=cache_dir
)
print("loaded model")
return config, processor, blip
return config, processor, blip
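The two mapping tables above pair each hook name with a module path template (the "%s" slot takes the layer index) and with the config field(s) that determine that hook's dimensionality, while create_blip loads the matching config, processor, and model. A minimal sketch, not part of this commit, assuming the file is importable as models.blip.modelings_intervenable_blip and that the default relative cache_dir is usable from the working directory (otherwise pass cache_dir explicitly):

from models.blip.modelings_intervenable_blip import (
    create_blip,
    blip_type_to_module_mapping,
    blip_type_to_dimension_mapping,
)

config, processor, blip = create_blip()

layer = 3
module_path, hook_type = blip_type_to_module_mapping["mlp_output"]
print(module_path % layer)   # text_encoder.encoder.layer[3].output
print(hook_type)             # the hook-side constant (CONST_INPUT_HOOK / CONST_OUTPUT_HOOK)
print(blip_type_to_dimension_mapping["mlp_output"])   # ("image_text_hidden_size",)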
