From c63d97640b9bee95d496548cbb32144e0c0d3870 Mon Sep 17 00:00:00 2001 From: PinetreePantry Date: Mon, 26 Feb 2024 23:15:59 -0800 Subject: [PATCH 01/16] Update transformers version --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index cc61eb7c..f90b1de5 100755 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ torch>=2.0.0 -transformers==4.37.0 +transformers==4.38.1 datasets==2.16.1 protobuf==3.20.* matplotlib==3.7.4 From 0e4003cdd3bcb3ee0b746883ca684ca9bea62c61 Mon Sep 17 00:00:00 2001 From: frankaging Date: Mon, 8 Apr 2024 14:32:18 -0700 Subject: [PATCH 02/16] removing seaborn dependency --- requirements.txt | 1 - tutorials/advanced_tutorials/Boundless_DAS.ipynb | 3 +-- tutorials/advanced_tutorials/tutorial_price_tagging_utils.py | 1 - .../basic_tutorials/tutorial_intervention_training_utils.py | 1 - 4 files changed, 1 insertion(+), 5 deletions(-) diff --git a/requirements.txt b/requirements.txt index b82f31cf..0772e829 100755 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,6 @@ transformers>=4.38.1 datasets>=2.16.1 protobuf>=3.20.0 matplotlib>=3.7.4 -seaborn>=0.13.1 ipywidgets>=8.1.1 plotnine>=0.12.4 huggingface-hub==0.20.3 diff --git a/tutorials/advanced_tutorials/Boundless_DAS.ipynb b/tutorials/advanced_tutorials/Boundless_DAS.ipynb index 605fd0c4..9119e175 100644 --- a/tutorials/advanced_tutorials/Boundless_DAS.ipynb +++ b/tutorials/advanced_tutorials/Boundless_DAS.ipynb @@ -69,7 +69,6 @@ "outputs": [], "source": [ "import torch\n", - "import seaborn as sns\n", "from tqdm import tqdm, trange\n", "from datasets import Dataset\n", "from torch.utils.data import DataLoader\n", @@ -617,7 +616,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.12" + "version": "3.10.13" } }, "nbformat": 4, diff --git a/tutorials/advanced_tutorials/tutorial_price_tagging_utils.py b/tutorials/advanced_tutorials/tutorial_price_tagging_utils.py index 0ec7c6a0..c0bada0c 100644 --- a/tutorials/advanced_tutorials/tutorial_price_tagging_utils.py +++ b/tutorials/advanced_tutorials/tutorial_price_tagging_utils.py @@ -1,7 +1,6 @@ import itertools import matplotlib.pyplot as plt import numpy as np -import seaborn as sns from functools import partial from typing import Dict, Optional, Sequence from torch.nn import functional as F diff --git a/tutorials/basic_tutorials/tutorial_intervention_training_utils.py b/tutorials/basic_tutorials/tutorial_intervention_training_utils.py index 015cfe7c..d30005a5 100644 --- a/tutorials/basic_tutorials/tutorial_intervention_training_utils.py +++ b/tutorials/basic_tutorials/tutorial_intervention_training_utils.py @@ -1,7 +1,6 @@ import itertools import matplotlib.pyplot as plt import numpy as np -import seaborn as sns from functools import partial from typing import Dict, Optional, Sequence from torch.nn import functional as F From 6668b1844220467c74d8500174d3670faaaa0b67 Mon Sep 17 00:00:00 2001 From: Zen Date: Mon, 8 Apr 2024 15:23:26 -0700 Subject: [PATCH 03/16] Update setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 4d5929af..25937cca 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( name="pyvene", - version="0.1.0", + version="0.1.1", description="Use Activation Intervention to Interpret Causal Mechanism of Model", long_description=long_description, long_description_content_type='text/markdown', From 6340a21b99f1a53ebc8ba0e650723058ce3efac3 Mon Sep 17 00:00:00 2001 From: Aryaman Arora Date: Mon, 8 Apr 2024 15:38:47 -0700 Subject: [PATCH 04/16] store str in config only --- pyvene/models/intervenable_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyvene/models/intervenable_base.py b/pyvene/models/intervenable_base.py index f1426648..62ba64be 100644 --- a/pyvene/models/intervenable_base.py +++ b/pyvene/models/intervenable_base.py @@ -47,7 +47,7 @@ def __init__(self, config, model, **kwargs): self.mode = config.mode intervention_type = config.intervention_types self.is_model_stateless = is_stateless(model) - self.config.model_type = type(model) # backfill + self.config.model_type = str(type(model)) # backfill self.use_fast = kwargs["use_fast"] if "use_fast" in kwargs else False self.model_has_grad = False From 203103b6cfb86629b83def3cb3dbd913264db709 Mon Sep 17 00:00:00 2001 From: Zen Date: Mon, 8 Apr 2024 20:59:23 -0700 Subject: [PATCH 05/16] Update README.md --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 0ebb1375..5e75c8b3 100644 --- a/README.md +++ b/README.md @@ -274,4 +274,5 @@ If you would like to read more works on this area, here is a list of papers that ## Star History -[![Star History Chart](https://api.star-history.com/svg?repos=stanfordnlp/pyvene&type=Date)](https://star-history.com/#stanfordnlp/pyvene&Date) +[![Star History Chart](https://api.star-history.com/svg?repos=stanfordnlp/pyvene,stanfordnlp/pyreft&type=Date)](https://star-history.com/#stanfordnlp/pyvene&stanfordnlp/pyreft&Date) + From 46712a2bc4668c0d2a80e51cfa33a1ee05252905 Mon Sep 17 00:00:00 2001 From: Aryaman Arora Date: Tue, 23 Apr 2024 12:58:30 -0700 Subject: [PATCH 06/16] +gemma, +llama classifier --- pyvene/models/gemma/__init__.py | 0 .../gemma/modelings_intervenable_gemma.py | 87 +++++++++++++++++++ pyvene/models/intervenable_modelcard.py | 9 ++ .../llama/modelings_intervenable_llama.py | 9 ++ 4 files changed, 105 insertions(+) create mode 100644 pyvene/models/gemma/__init__.py create mode 100644 pyvene/models/gemma/modelings_intervenable_gemma.py diff --git a/pyvene/models/gemma/__init__.py b/pyvene/models/gemma/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pyvene/models/gemma/modelings_intervenable_gemma.py b/pyvene/models/gemma/modelings_intervenable_gemma.py new file mode 100644 index 00000000..3fd2207c --- /dev/null +++ b/pyvene/models/gemma/modelings_intervenable_gemma.py @@ -0,0 +1,87 @@ +""" +Each modeling file in this library is a mapping between +abstract naming of intervention anchor points and actual +model module defined in the huggingface library. + +We also want to let the intervention library know how to +config the dimensions of intervention based on model config +defined in the huggingface library. +""" + + +import torch +from ..constants import * + + +gemma_type_to_module_mapping = { + "block_input": ("layers[%s]", CONST_INPUT_HOOK), + "block_output": ("layers[%s]", CONST_OUTPUT_HOOK), + "mlp_activation": ("layers[%s].mlp.act_fn", CONST_OUTPUT_HOOK), + "mlp_output": ("layers[%s].mlp", CONST_OUTPUT_HOOK), + "mlp_input": ("layers[%s].mlp", CONST_INPUT_HOOK), + "attention_value_output": ("layers[%s].self_attn.o_proj", CONST_INPUT_HOOK), + "head_attention_value_output": ("layers[%s].self_attn.o_proj", CONST_INPUT_HOOK), + "attention_output": ("layers[%s].self_attn", CONST_OUTPUT_HOOK), + "attention_input": ("layers[%s].self_attn", CONST_INPUT_HOOK), + "query_output": ("layers[%s].self_attn.q_proj", CONST_OUTPUT_HOOK), + "key_output": ("layers[%s].self_attn.k_proj", CONST_OUTPUT_HOOK), + "value_output": ("layers[%s].self_attn.v_proj", CONST_OUTPUT_HOOK), + "head_query_output": ("layers[%s].self_attn.q_proj", CONST_OUTPUT_HOOK), + "head_key_output": ("layers[%s].self_attn.k_proj", CONST_OUTPUT_HOOK), + "head_value_output": ("layers[%s].self_attn.v_proj", CONST_OUTPUT_HOOK), +} + + +gemma_type_to_dimension_mapping = { + "block_input": ("hidden_size",), + "block_output": ("hidden_size",), + "mlp_activation": ("intermediate_size",), + "mlp_output": ("hidden_size",), + "mlp_input": ("hidden_size",), + "attention_value_output": ("hidden_size",), + "head_attention_value_output": ("hidden_size/num_attention_heads",), + "attention_output": ("hidden_size",), + "attention_input": ("hidden_size",), + "query_output": ("hidden_size",), + "key_output": ("hidden_size",), + "value_output": ("hidden_size",), + "head_query_output": ("hidden_size/num_attention_heads",), + "head_key_output": ("hidden_size/num_attention_heads",), + "head_value_output": ("hidden_size/num_attention_heads",), +} + + +"""gemma model with LM head""" +gemma_lm_type_to_module_mapping = {} +for k, v in gemma_type_to_module_mapping.items(): + gemma_lm_type_to_module_mapping[k] = (f"model.{v[0]}", v[1]) + + +gemma_lm_type_to_dimension_mapping = gemma_type_to_dimension_mapping + + +"""gemma model with classifier head""" +gemma_classifier_type_to_module_mapping = {} +for k, v in gemma_type_to_module_mapping.items(): + gemma_classifier_type_to_module_mapping[k] = (f"model.{v[0]}", v[1]) + + +gemma_classifier_type_to_dimension_mapping = gemma_type_to_dimension_mapping + + +def create_gemma( + name="sharpbai/alpaca-7b-merged", cache_dir=None, dtype=torch.bfloat16 +): + """Creates a Gemma Causal LM model, config, and tokenizer from the given name and revision""" + from transformers import GemmaForCausalLM, GemmaTokenizer, GemmaConfig + + config = GemmaConfig.from_pretrained(name, cache_dir=cache_dir) + tokenizer = GemmaTokenizer.from_pretrained(name, cache_dir=cache_dir) + gemma = GemmaForCausalLM.from_pretrained( + name, + config=config, + cache_dir=cache_dir, + torch_dtype=dtype, # save memory + ) + print("loaded model") + return config, tokenizer, gemma diff --git a/pyvene/models/intervenable_modelcard.py b/pyvene/models/intervenable_modelcard.py index 42013ba6..a445bfc6 100644 --- a/pyvene/models/intervenable_modelcard.py +++ b/pyvene/models/intervenable_modelcard.py @@ -1,6 +1,7 @@ from .constants import * from .llama.modelings_intervenable_llama import * from .mistral.modellings_intervenable_mistral import * +from .gemma.modelings_intervenable_gemma import * from .gpt2.modelings_intervenable_gpt2 import * from .gpt_neo.modelings_intervenable_gpt_neo import * from .gpt_neox.modelings_intervenable_gpt_neox import * @@ -39,12 +40,16 @@ hf_models.gpt2.modeling_gpt2.GPT2ForSequenceClassification: gpt2_classifier_type_to_module_mapping, hf_models.llama.modeling_llama.LlamaModel: llama_type_to_module_mapping, hf_models.llama.modeling_llama.LlamaForCausalLM: llama_lm_type_to_module_mapping, + hf_models.llama.modeling_llama.LlamaForSequenceClassification: llama_classifier_type_to_module_mapping, hf_models.gpt_neo.modeling_gpt_neo.GPTNeoModel: gpt_neo_type_to_module_mapping, hf_models.gpt_neo.modeling_gpt_neo.GPTNeoForCausalLM: gpt_neo_lm_type_to_module_mapping, hf_models.gpt_neox.modeling_gpt_neox.GPTNeoXModel: gpt_neox_type_to_module_mapping, hf_models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM: gpt_neox_lm_type_to_module_mapping, hf_models.mistral.modeling_mistral.MistralModel: mistral_type_to_module_mapping, hf_models.mistral.modeling_mistral.MistralForCausalLM: mistral_lm_type_to_module_mapping, + hf_models.gemma.modeling_gemma.GemmaModel: gemma_type_to_module_mapping, + hf_models.gemma.modeling_gemma.GemmaForCausalLM: gemma_lm_type_to_module_mapping, + hf_models.gemma.modeling_gemma.GemmaForSequenceClassification: gemma_classifier_type_to_module_mapping, hf_models.blip.modeling_blip.BlipForQuestionAnswering: blip_type_to_module_mapping, hf_models.blip.modeling_blip.BlipForImageTextRetrieval: blip_itm_type_to_module_mapping, BlipWrapper: blip_wrapper_type_to_module_mapping, @@ -65,12 +70,16 @@ hf_models.gpt2.modeling_gpt2.GPT2ForSequenceClassification: gpt2_classifier_type_to_dimension_mapping, hf_models.llama.modeling_llama.LlamaModel: llama_type_to_dimension_mapping, hf_models.llama.modeling_llama.LlamaForCausalLM: llama_lm_type_to_dimension_mapping, + hf_models.llama.modeling_llama.LlamaForSequenceClassification: llama_classifier_type_to_dimension_mapping, hf_models.gpt_neo.modeling_gpt_neo.GPTNeoModel: gpt_neo_type_to_dimension_mapping, hf_models.gpt_neo.modeling_gpt_neo.GPTNeoForCausalLM: gpt_neo_lm_type_to_dimension_mapping, hf_models.gpt_neox.modeling_gpt_neox.GPTNeoXModel: gpt_neox_type_to_dimension_mapping, hf_models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM: gpt_neox_lm_type_to_dimension_mapping, hf_models.mistral.modeling_mistral.MistralModel: mistral_type_to_dimension_mapping, hf_models.mistral.modeling_mistral.MistralForCausalLM: mistral_lm_type_to_dimension_mapping, + hf_models.gemma.modeling_gemma.GemmaModel: gemma_type_to_dimension_mapping, + hf_models.gemma.modeling_gemma.GemmaForCausalLM: gemma_lm_type_to_dimension_mapping, + hf_models.gemma.modeling_gemma.GemmaForSequenceClassification: gemma_classifier_type_to_dimension_mapping, hf_models.blip.modeling_blip.BlipForQuestionAnswering: blip_type_to_dimension_mapping, hf_models.blip.modeling_blip.BlipForImageTextRetrieval: blip_itm_type_to_dimension_mapping, BlipWrapper: blip_wrapper_type_to_dimension_mapping, diff --git a/pyvene/models/llama/modelings_intervenable_llama.py b/pyvene/models/llama/modelings_intervenable_llama.py index ff940b9b..82817b9b 100644 --- a/pyvene/models/llama/modelings_intervenable_llama.py +++ b/pyvene/models/llama/modelings_intervenable_llama.py @@ -60,6 +60,15 @@ llama_lm_type_to_dimension_mapping = llama_type_to_dimension_mapping +"""llama model with classifier head""" +llama_classifier_type_to_module_mapping = {} +for k, v in llama_type_to_module_mapping.items(): + llama_classifier_type_to_module_mapping[k] = (f"model.{v[0]}", v[1]) + + +llama_classifier_type_to_dimension_mapping = llama_type_to_dimension_mapping + + def create_llama( name="sharpbai/alpaca-7b-merged", cache_dir=None, dtype=torch.bfloat16 ): From d57fcfb5b7bb3be95c85f529e2066b96a62aa23a Mon Sep 17 00:00:00 2001 From: Aryaman Arora Date: Tue, 23 Apr 2024 13:00:09 -0700 Subject: [PATCH 07/16] fix typos in mistral --- pyvene/models/mistral/modellings_intervenable_mistral.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyvene/models/mistral/modellings_intervenable_mistral.py b/pyvene/models/mistral/modellings_intervenable_mistral.py index dbe7118c..edbd6176 100644 --- a/pyvene/models/mistral/modellings_intervenable_mistral.py +++ b/pyvene/models/mistral/modellings_intervenable_mistral.py @@ -51,7 +51,7 @@ } -"""llama model with LM head""" +"""mistral model with LM head""" mistral_lm_type_to_module_mapping = {} for k, v in mistral_type_to_module_mapping.items(): mistral_lm_type_to_module_mapping[k] = (f"model.{v[0]}", v[1]) @@ -68,11 +68,11 @@ def create_mistral( config = AutoConfig.from_pretrained(name, cache_dir=cache_dir) tokenizer = AutoTokenizer.from_pretrained(name, cache_dir=cache_dir) - llama = AutoModelForCausalLM.from_pretrained( + mistral = AutoModelForCausalLM.from_pretrained( name, config=config, cache_dir=cache_dir, torch_dtype=torch.bfloat16, # save memory ) print("loaded model") - return config, tokenizer, llama + return config, tokenizer, mistral From a6fe3054b43d8e63541083851419f97ba30e48ed Mon Sep 17 00:00:00 2001 From: Aryaman Arora Date: Tue, 23 Apr 2024 13:02:11 -0700 Subject: [PATCH 08/16] fix gemma creator --- pyvene/models/gemma/modelings_intervenable_gemma.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyvene/models/gemma/modelings_intervenable_gemma.py b/pyvene/models/gemma/modelings_intervenable_gemma.py index 3fd2207c..09135bb8 100644 --- a/pyvene/models/gemma/modelings_intervenable_gemma.py +++ b/pyvene/models/gemma/modelings_intervenable_gemma.py @@ -70,7 +70,7 @@ def create_gemma( - name="sharpbai/alpaca-7b-merged", cache_dir=None, dtype=torch.bfloat16 + name="google/gemma-2b-it", cache_dir=None, dtype=torch.bfloat16 ): """Creates a Gemma Causal LM model, config, and tokenizer from the given name and revision""" from transformers import GemmaForCausalLM, GemmaTokenizer, GemmaConfig From 59beb564c1a6b7b6aa0f086e6cac682db0c21e77 Mon Sep 17 00:00:00 2001 From: PinetreePantry Date: Tue, 30 Apr 2024 22:38:28 -0700 Subject: [PATCH 09/16] Llava implementation --- pyvene/__init__.py | 1 + pyvene/models/intervenable_modelcard.py | 3 + pyvene/models/llava/__init__.py | 0 .../llava/modelings_intervenable_llava.py | 92 +++++++++++++++++++ 4 files changed, 96 insertions(+) create mode 100644 pyvene/models/llava/__init__.py create mode 100644 pyvene/models/llava/modelings_intervenable_llava.py diff --git a/pyvene/__init__.py b/pyvene/__init__.py index c6c44dbe..69e80a35 100644 --- a/pyvene/__init__.py +++ b/pyvene/__init__.py @@ -41,6 +41,7 @@ from .models.gru.modelings_intervenable_gru import create_gru from .models.gru.modelings_intervenable_gru import create_gru_lm from .models.gru.modelings_intervenable_gru import create_gru_classifier +from .models.llava.modelings_intervenable_llava import create_llava from .models.gru.modelings_gru import GRUConfig from .models.llama.modelings_intervenable_llama import create_llama from .models.mlp.modelings_intervenable_mlp import create_mlp_classifier diff --git a/pyvene/models/intervenable_modelcard.py b/pyvene/models/intervenable_modelcard.py index a445bfc6..f8997850 100644 --- a/pyvene/models/intervenable_modelcard.py +++ b/pyvene/models/intervenable_modelcard.py @@ -10,6 +10,7 @@ from .blip.modelings_intervenable_blip import * from .blip.modelings_intervenable_blip_itm import * from .backpack_gpt2.modelings_intervenable_backpack_gpt2 import * +from .llava.modelings_intervenable_llava import * ######################################################################### @@ -41,6 +42,7 @@ hf_models.llama.modeling_llama.LlamaModel: llama_type_to_module_mapping, hf_models.llama.modeling_llama.LlamaForCausalLM: llama_lm_type_to_module_mapping, hf_models.llama.modeling_llama.LlamaForSequenceClassification: llama_classifier_type_to_module_mapping, + hf_models.llava.modeling_llava.LlavaForConditionalGeneration: llava_type_to_module_mapping, hf_models.gpt_neo.modeling_gpt_neo.GPTNeoModel: gpt_neo_type_to_module_mapping, hf_models.gpt_neo.modeling_gpt_neo.GPTNeoForCausalLM: gpt_neo_lm_type_to_module_mapping, hf_models.gpt_neox.modeling_gpt_neox.GPTNeoXModel: gpt_neox_type_to_module_mapping, @@ -71,6 +73,7 @@ hf_models.llama.modeling_llama.LlamaModel: llama_type_to_dimension_mapping, hf_models.llama.modeling_llama.LlamaForCausalLM: llama_lm_type_to_dimension_mapping, hf_models.llama.modeling_llama.LlamaForSequenceClassification: llama_classifier_type_to_dimension_mapping, + hf_models.llava.modeling_llava.LlavaForConditionalGeneration: llava_type_to_dimension_mapping, hf_models.gpt_neo.modeling_gpt_neo.GPTNeoModel: gpt_neo_type_to_dimension_mapping, hf_models.gpt_neo.modeling_gpt_neo.GPTNeoForCausalLM: gpt_neo_lm_type_to_dimension_mapping, hf_models.gpt_neox.modeling_gpt_neox.GPTNeoXModel: gpt_neox_type_to_dimension_mapping, diff --git a/pyvene/models/llava/__init__.py b/pyvene/models/llava/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pyvene/models/llava/modelings_intervenable_llava.py b/pyvene/models/llava/modelings_intervenable_llava.py new file mode 100644 index 00000000..c37d019b --- /dev/null +++ b/pyvene/models/llava/modelings_intervenable_llava.py @@ -0,0 +1,92 @@ +""" +Each modeling file in this library is a mapping between +abstract naming of intervention anchor points and actual +model module defined in the huggingface library. + +We also want to let the intervention library know how to +config the dimensions of intervention based on model config +defined in the huggingface library. +""" + + +import torch +from ..constants import * + +llava_type_to_module_mapping = { + "block_input": ("language_model.model.layers[%s]", CONST_INPUT_HOOK), + "block_output": ("language_model.model.layers[%s]", CONST_OUTPUT_HOOK), + "mlp_activation": ("language_model.model.layers[%s].mlp.act_fn", CONST_OUTPUT_HOOK), + "mlp_output": ("language_model.model.layers[%s].mlp", CONST_OUTPUT_HOOK), + "mlp_input": ("language_model.model.layers[%s].mlp", CONST_INPUT_HOOK), + "attention_value_output": ("language_model.model.layers[%s].self_attn.o_proj", CONST_INPUT_HOOK), + "head_attention_value_output": ("language_model.model.layers[%s].self_attn.o_proj", CONST_INPUT_HOOK), + "attention_output": ("language_model.model.layers[%s].self_attn", CONST_OUTPUT_HOOK), + "attention_input": ("language_model.model.layers[%s].self_attn", CONST_INPUT_HOOK), + "query_output": ("language_model.model.layers[%s].self_attn.q_proj", CONST_OUTPUT_HOOK), + "key_output": ("language_model.model.layers[%s].self_attn.k_proj", CONST_OUTPUT_HOOK), + "value_output": ("language_model.model.layers[%s].self_attn.v_proj", CONST_OUTPUT_HOOK), + "head_query_output": ("language_model.model.layers[%s].self_attn.q_proj", CONST_OUTPUT_HOOK), + "head_key_output": ("language_model.model.layers[%s].self_attn.k_proj", CONST_OUTPUT_HOOK), + "head_value_output": ("language_model.model.layers[%s].self_attn.v_proj", CONST_OUTPUT_HOOK), +} + + +llava_type_to_dimension_mapping = { + "block_input": ("text_config.hidden_size",), + "block_output": ("text_config.hidden_size",), + "mlp_activation": ("text_config.intermediate_size",), + "mlp_output": ("text_config.hidden_size",), + "mlp_input": ("text_config.hidden_size",), + "attention_value_output": ("text_config.hidden_size",), + "head_attention_value_output": ("text_config.hidden_size/text_config.num_attention_heads",), + "attention_output": ("text_config.hidden_size",), + "attention_input": ("text_config.hidden_size",), + "query_output": ("text_config.hidden_size",), + "key_output": ("text_config.hidden_size",), + "value_output": ("text_config.hidden_size",), + "head_query_output": ("text_config.hidden_size/text_config.num_attention_heads",), + "head_key_output": ("text_config.hidden_size/text_config.num_attention_heads",), + "head_value_output": ("text_config.hidden_size/text_config.num_attention_heads",), +} + + +"""llava model with LM head""" +llava_lm_type_to_module_mapping = {} +for k, v in llava_type_to_module_mapping.items(): + llava_lm_type_to_module_mapping[k] = (f"model.{v[0]}", v[1]) + + +llava_lm_type_to_dimension_mapping = llava_type_to_dimension_mapping + + +"""llava model with classifier head""" +llava_classifier_type_to_module_mapping = {} +for k, v in llava_type_to_module_mapping.items(): + llava_classifier_type_to_module_mapping[k] = (f"model.{v[0]}", v[1]) + + +llava_classifier_type_to_dimension_mapping = llava_type_to_dimension_mapping + + + + +def create_llava( + name="llava-hf/llava-1.5-7b-hf", cache_dir=None, dtype=torch.bfloat16 +): + """Creates a llava Causal LM model, config, and tokenizer from the given name and revision""" + from transformers import LlavaForConditionalGeneration, LlavaConfig, AutoTokenizer, AutoProcessor + + config = LlavaConfig.from_pretrained(name, cache_dir=cache_dir) + tokenizer = AutoTokenizer.from_pretrained(name, use_fast=False) + llava = LlavaForConditionalGeneration.from_pretrained( + name, + config=config, + cache_dir=cache_dir, + torch_dtype=dtype, + ) + + image_processor = AutoProcessor.from_pretrained(name) + + print("loaded model") + return config, tokenizer, llava, image_processor + From 38de8544ac64a523e09aa6a4d0be9e65f33a0844 Mon Sep 17 00:00:00 2001 From: Amir Zur Date: Wed, 1 May 2024 09:36:04 -0700 Subject: [PATCH 10/16] Added `use_cache` flag to intervenable model forward call --- pyvene/models/intervenable_base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pyvene/models/intervenable_base.py b/pyvene/models/intervenable_base.py index 62ba64be..21eb6509 100644 --- a/pyvene/models/intervenable_base.py +++ b/pyvene/models/intervenable_base.py @@ -1320,6 +1320,7 @@ def forward( labels: Optional[torch.LongTensor] = None, output_original_output: Optional[bool] = False, return_dict: Optional[bool] = None, + use_cache: Optional[bool] = True, ): """ Main forward function that serves a wrapper to @@ -1440,9 +1441,9 @@ def forward( # run intervened forward if labels is not None: - counterfactual_outputs = self.model(**base, labels=labels) + counterfactual_outputs = self.model(**base, labels=labels, use_cache=use_cache) else: - counterfactual_outputs = self.model(**base) + counterfactual_outputs = self.model(**base, use_cache=use_cache) set_handlers_to_remove.remove() self._output_validation() From 4d7eabd16c9e548fa22d882b80216537b8400e6b Mon Sep 17 00:00:00 2001 From: Amir Zur Date: Wed, 1 May 2024 14:09:44 -0700 Subject: [PATCH 11/16] Debugged `use_cache` flag for MLP model (which doesn't take in the `use_cache` argument in its forward call) --- pyvene/models/intervenable_base.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/pyvene/models/intervenable_base.py b/pyvene/models/intervenable_base.py index 21eb6509..68aea2fa 100644 --- a/pyvene/models/intervenable_base.py +++ b/pyvene/models/intervenable_base.py @@ -1440,10 +1440,14 @@ def forward( ) # run intervened forward - if labels is not None: - counterfactual_outputs = self.model(**base, labels=labels, use_cache=use_cache) - else: - counterfactual_outputs = self.model(**base, use_cache=use_cache) + model_kwargs = {} + if labels is not None: # for training + model_kwargs["labels"] = labels + if 'use_cache' in self.model.config.to_dict(): # for transformer models + model_kwargs["use_cache"] = use_cache + + counterfactual_outputs = self.model(**base, **model_kwargs) + set_handlers_to_remove.remove() self._output_validation() From 263f98d0e15751d82d9adb821f0dedf97b4bea5c Mon Sep 17 00:00:00 2001 From: atticusg Date: Fri, 3 May 2024 10:46:56 -0700 Subject: [PATCH 12/16] Create visualization.py and add heatmaps --- pyvene/models/visualization.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 pyvene/models/visualization.py diff --git a/pyvene/models/visualization.py b/pyvene/models/visualization.py new file mode 100644 index 00000000..0c6b6006 --- /dev/null +++ b/pyvene/models/visualization.py @@ -0,0 +1,31 @@ +import seaborn + +def rotation_token_heatmap(rotate_layer, + tokens, + token_size, + variables, + intervention_sizes): + + W = rotate_layer.weight.data + in_dim, out_dim = W.shape + + assert in_dim % token_size == 0 + assert in_dim / token_size >= len(tokens) + + assert out_dim % intervention_size == 0 + assert out_dim / intervention_size >= len(variables) + + heatmap = [] + for j in range(len(variables)): + row = [] + for i in range(len(tokens)): + row.append(W[i*token_size:(i+1)*token_size, j*intervention_size:(j+1)*intervention_size].sum()) + mean = sum(row) + heatmap.append([x/mean for x in row]) + return seaborn.heatmap(heatmap, + xticklabels=variables, + yticklabels=tokens) + + + + From 68453ff9e5331ceb001302d3388359df75723f5c Mon Sep 17 00:00:00 2001 From: atticusg Date: Sat, 4 May 2024 17:58:15 -0700 Subject: [PATCH 13/16] Update to printing causal models --- pyvene/data_generators/causal_model.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/pyvene/data_generators/causal_model.py b/pyvene/data_generators/causal_model.py index d46307d9..de8a2d93 100644 --- a/pyvene/data_generators/causal_model.py +++ b/pyvene/data_generators/causal_model.py @@ -35,9 +35,6 @@ def __init__( assert variable in self.values assert variable in self.children assert variable in self.functions - assert len(inspect.getfullargspec(self.functions[variable])[0]) == len( - self.parents[variable] - ) if timesteps is not None: assert variable in timesteps for variable2 in copy.copy(self.variables): @@ -79,6 +76,8 @@ def __init__( self.equiv_classes = equiv_classes else: self.equiv_classes = {} + + def generate_equiv_classes(self): for var in self.variables: if var in self.inputs or var in self.equiv_classes: continue @@ -113,7 +112,7 @@ def generate_timesteps(self): def marginalize(self, target): pass - def print_structure(self, pos=None): + def print_structure(self, pos=None, font=12, node_size=1000): G = nx.DiGraph() G.add_edges_from( [ @@ -123,7 +122,7 @@ def print_structure(self, pos=None): ] ) plt.figure(figsize=(10, 10)) - nx.draw_networkx(G, with_labels=True, node_color="green", pos=self.pos) + nx.draw_networkx(G, with_labels=True, node_color="green", pos=self.pos, font_size=font, node_size=node_size) plt.show() def find_live_paths(self, intervention): @@ -149,12 +148,9 @@ def find_live_paths(self, intervention): del paths[1] return paths - def print_setting(self, total_setting, display=None): - labeler = lambda var: var + ": " + str(total_setting[var]) \ - if display is None or display[var] \ - else var + def print_setting(self, total_setting, font=12, node_size=1000): relabeler = { - var: labeler(var) for var in self.variables + var: var + ": " + str(total_setting[var]) for var in self.variables } G = nx.DiGraph() G.add_edges_from( @@ -170,7 +166,7 @@ def print_setting(self, total_setting, display=None): if self.pos is not None: for var in self.pos: newpos[relabeler[var]] = self.pos[var] - nx.draw_networkx(G, with_labels=True, node_color="green", pos=newpos) + nx.draw_networkx(G, with_labels=True, node_color="green", pos=newpos, font_size=font, node_size=node_size) plt.show() def run_forward(self, intervention=None): @@ -233,11 +229,14 @@ def sample_input(self, mandatory=None): def sample_input_tree_balanced(self, output_var=None, output_var_value=None): assert output_var is not None or len(self.outputs) == 1 + self.generate_equiv_classes() + if output_var is None: output_var = self.outputs[0] if output_var_value is None: output_var_value = random.choice(self.values[output_var]) + def create_input(var, value, input={}): parent_values = random.choice(self.equiv_classes[var][value]) for parent in parent_values: From e64338a6c329413d058e33426e08e570ae1fb79b Mon Sep 17 00:00:00 2001 From: atticusg Date: Sun, 5 May 2024 10:39:43 -0700 Subject: [PATCH 14/16] move visualization file --- pyvene/{models => Analysis}/visualization.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) rename pyvene/{models => Analysis}/visualization.py (71%) diff --git a/pyvene/models/visualization.py b/pyvene/Analysis/visualization.py similarity index 71% rename from pyvene/models/visualization.py rename to pyvene/Analysis/visualization.py index 0c6b6006..81d22cea 100644 --- a/pyvene/models/visualization.py +++ b/pyvene/Analysis/visualization.py @@ -1,10 +1,11 @@ import seaborn +import torch def rotation_token_heatmap(rotate_layer, tokens, token_size, variables, - intervention_sizes): + intervention_size): W = rotate_layer.weight.data in_dim, out_dim = W.shape @@ -19,12 +20,12 @@ def rotation_token_heatmap(rotate_layer, for j in range(len(variables)): row = [] for i in range(len(tokens)): - row.append(W[i*token_size:(i+1)*token_size, j*intervention_size:(j+1)*intervention_size].sum()) + row.append(torch.norm(W[i*token_size:(i+1)*token_size, j*intervention_size:(j+1)*intervention_size])) mean = sum(row) heatmap.append([x/mean for x in row]) return seaborn.heatmap(heatmap, - xticklabels=variables, - yticklabels=tokens) + xticklabels=tokens, + yticklabels=variables) From ca26dd4fec7e190b1e20fa6aaa2ead53635cb580 Mon Sep 17 00:00:00 2001 From: atticusg Date: Sun, 5 May 2024 14:44:32 -0700 Subject: [PATCH 15/16] Modify Causal Model Unit Test --- pyvene/Analysis/visualization.py | 6 +----- tests/unit_tests/CausalModelTestCase.py | 1 + 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/pyvene/Analysis/visualization.py b/pyvene/Analysis/visualization.py index 81d22cea..cd70ea4c 100644 --- a/pyvene/Analysis/visualization.py +++ b/pyvene/Analysis/visualization.py @@ -25,8 +25,4 @@ def rotation_token_heatmap(rotate_layer, heatmap.append([x/mean for x in row]) return seaborn.heatmap(heatmap, xticklabels=tokens, - yticklabels=variables) - - - - + yticklabels=variables) \ No newline at end of file diff --git a/tests/unit_tests/CausalModelTestCase.py b/tests/unit_tests/CausalModelTestCase.py index c37178c7..90996afc 100644 --- a/tests/unit_tests/CausalModelTestCase.py +++ b/tests/unit_tests/CausalModelTestCase.py @@ -34,6 +34,7 @@ def setUpClass(self): self.parents, self.functions ) + self.causal_model.generate_equiv_classes() def test_initialization(self): inputs = ['A', 'B'] From 49838a885f1cbbf4b8ab0d9f17292df63e887bb1 Mon Sep 17 00:00:00 2001 From: atticusg Date: Sun, 5 May 2024 14:45:59 -0700 Subject: [PATCH 16/16] Change folder name --- pyvene/{Analysis => analyses}/visualization.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename pyvene/{Analysis => analyses}/visualization.py (100%) diff --git a/pyvene/Analysis/visualization.py b/pyvene/analyses/visualization.py similarity index 100% rename from pyvene/Analysis/visualization.py rename to pyvene/analyses/visualization.py