diff --git a/.gitignore b/.gitignore index 54109c0..3d18747 100755 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,5 @@ *.pt *.pth +*pyc +*build* +*.egg* diff --git a/README.md b/README.md index 7455c49..45d8105 100755 --- a/README.md +++ b/README.md @@ -9,30 +9,59 @@ As depicted in following examples, one is able to highlight various discriminati [Paper Link](https://arxiv.org/abs/2301.08110) +## roadmap + - continue to cleanup repo + - i.p. remove Explainer class and other overhead + - more examples + - hf integration? ## prelim -This repo includes the XAI methods AtMan, Chefer, and a Captum interface for IG, GradCam etc. for the language-model GPT-J and vision-language model [MAGMA](https://github.com/Aleph-Alpha/magma) and [BLIP](https://colab.research.google.com/github/salesforce/BLIP). +This repo includes the XAI methods AtMan, Chefer, and a Captum interface for IG, GradCam etc. for the language-model GPT-J and vision-language model [MAGMA](https://github.com/Aleph-Alpha/magma) and [BLIP](https://colab.research.google.com/github/salesforce/BLIP). (Big props to Mayukh Deb.) To install all required dependencies, run the following command, e.g. in a conda environment with python3.8: ``` bash startup-hook.sh ``` +Note: further model-checkpoints will be downloaded when executing for the first time. Sometimes CLIP fails to verify on the first execution -> running again works usually. -# examples -## image-text/ MAGMA -TODO: examples for different methods (script + image) +# examples with MAGMA ``` cd atman-magma -python example_explain_panda.py +``` +## image-text/ MAGMA x AtMan +requires 1 RTX 3090 + +``` +python example_explain_panda_atman.py +``` + +## image-text/ MAGMA x Chefer +requires 1 A100 + +``` +python example_explain_panda_chefer.py run plot_panda.ipynb ``` +## image-text/ MAGMA x Captum IxG, ... +requires 1 A100 + +``` +python example_explain_panda_captum.py +run plot_panda.ipynb +``` + +## image-text/ rollout +on it + ## image-text/ BLIP +on it ## text/ GPT-J +on it -# more to read +# Method and Evaluation ![steering and measuring](figs/fig2.png) diff --git a/atman-magma/example_attention_rollout.py b/atman-magma/example_attention_rollout.py index d4462f9..a49d9a9 100755 --- a/atman-magma/example_attention_rollout.py +++ b/atman-magma/example_attention_rollout.py @@ -1,5 +1,5 @@ from atman_magma.magma import Magma -from atman.attention_rollout import AttentionRollout +from atman_magma.attention_rollout import AttentionRolloutMagma print('loading model...') model = Magma.from_checkpoint( diff --git a/atman-magma/example_big_loop.py b/atman-magma/example_big_loop.py deleted file mode 100755 index ead71f4..0000000 --- a/atman-magma/example_big_loop.py +++ /dev/null @@ -1,56 +0,0 @@ -import yaml -# from atman_magma.magma import Magma -# from atman_magma.explainer import Explainer -from atman_magma.logit_parsing import get_delta_cross_entropies, get_delta_logits -# from atman_magma.openimages_eval import run_eval - - -# from multimodal_explain_eval.dataloader import DataLoader -# from multimodal_explain_eval.utils import load_json_as_dict - -with open("config.yml", "r") as stream: - config = yaml.safe_load(stream) - -output_folder_root = config["files"]["output_dir"] - -conceptual_suppression_threshold_values = [0.7] - -suppression_factor_values = [0.1] - -manipulate_attn_scores_after_scaling_values = [True, False] - -modify_suppression_factor_based_on_cossim_values = [True, False] - -possible_logit_parsing_functions = [get_delta_cross_entropies, get_delta_logits] - -# metadata = load_json_as_dict( -# filename = config['files']['metadata_filename'] -# ) -# dataloader = DataLoader( -# metadata=metadata -# ) - -# print('loading model...') -# model = Magma.from_checkpoint( -# checkpoint_path = './magma_checkpoint.pt', -# device = 'cuda:0' -# ) - -output_folders = [] - -for conceptual_suppression_threshold in conceptual_suppression_threshold_values: - for suppression_factor in suppression_factor_values: - for ( - manipulate_attn_scores_after_scaling - ) in manipulate_attn_scores_after_scaling_values: - for ( - modify_suppression_factor_based_on_cossim - ) in modify_suppression_factor_based_on_cossim_values: - for logit_parsing_fn in possible_logit_parsing_functions: - - ## x if x is not None else None - output_folder_name = f"{output_folder_root}/conceptual_suppression_threshold_{conceptual_suppression_threshold if conceptual_suppression_threshold is not None else None}_suppression_factor_{suppression_factor}_manipulate_attn_scores_after_scaling_{manipulate_attn_scores_after_scaling}_modify_suppression_factor_based_on_cossim_{modify_suppression_factor_based_on_cossim}_logit_parsing_fn_{logit_parsing_fn.__qualname__}" - - print(output_folder_name) - -print("eval complete :)") diff --git a/atman-magma/example_explain_panda.py b/atman-magma/example_explain_panda_atman.py similarity index 61% rename from atman-magma/example_explain_panda.py rename to atman-magma/example_explain_panda_atman.py index c40d954..771bbb8 100755 --- a/atman-magma/example_explain_panda.py +++ b/atman-magma/example_explain_panda_atman.py @@ -2,6 +2,10 @@ from atman_magma.explainer import Explainer from atman_magma.utils import split_str_into_tokens from atman_magma.logit_parsing import get_delta_cross_entropies +import matplotlib.pyplot as plt +import cv2 +from atman_magma.outputs import DeltaCrossEntropiesOutput +import numpy as np print('loading model...') @@ -11,13 +15,10 @@ device = device ) - -''' -Image example -''' from magma.image_input import ImageInput import PIL.Image as PilImage + ex = Explainer( model = model, device = device, @@ -26,7 +27,6 @@ conceptual_suppression_threshold = 0.75 ) - prompt =[ ## supports urls and path/to/image ImageInput('',pil=PilImage.open('openimages-panda.jpg')), @@ -36,19 +36,10 @@ ## returns a tensor of shape: (1, 149, 4096) embeddings = model.preprocess_inputs(prompt.copy()) -## returns a list of length embeddings.shape[0] (batch size) -# output = model.generate( -# embeddings = embeddings, -# max_steps = 5, -# temperature = 0.001, -# top_k = 1, -# top_p = 0.0, -# ) -# completion = output[0] - +label ='Panda' logit_outputs = ex.collect_logits_by_manipulating_attention( prompt = prompt.copy(), - target = 'Panda', + target = label, max_batch_size=1, # prompt_explain_indices=[i for i in range(10)] ) @@ -57,4 +48,20 @@ output = logit_outputs ) -results.save('output.json') +image_filename = 'openimages-panda.jpg' + +label_tokens = model.tokenizer.encode(label) + +image = np.zeros((12,12)) +for i in range(len(label_tokens)): + image += results.show_image(image_token_start_idx = 0, target_token_idx= i) **2 + +# image[image<0.6]=1.0 +fig, ax = plt.subplots(nrows=1, ncols=2, figsize = (15 , 6)) +title = f'' +fig.suptitle(title) +ax[0].imshow(cv2.cvtColor(cv2.imread(image_filename), cv2.COLOR_BGR2RGB)) +ax[1].imshow(image) + +fig.savefig('panda-explained-atman.jpg') +print('panda-explained-atman.jpg') diff --git a/atman-magma/example_explain_panda_captum.py b/atman-magma/example_explain_panda_captum.py new file mode 100755 index 0000000..f883660 --- /dev/null +++ b/atman-magma/example_explain_panda_captum.py @@ -0,0 +1,68 @@ +from PIL import Image +import matplotlib.pyplot as plt + +from atman_magma.captum_helper import ( + CaptumMagma, +) +from multimodal_explain_eval.utils import check_if_a_or_an_and_get_prefix +import numpy as np +from atman_magma.magma import Magma +from magma.image_input import ImageInput +from PIL import Image + +from captum.attr import IntegratedGradients, InputXGradient, GuidedGradCam +from captum.attr import LayerGradCam + + +targets = 'Panda' +final_img = Image.open('openimages-panda.jpg') + + +print('loading model...') +model = Magma.from_checkpoint( + checkpoint_path = "./mp_rank_00_model_states.pt", + device = 'cuda:0' +) + +cmagma = CaptumMagma(magma = model) +# captum_tool = IntegratedGradients(cmagma) +# captum_tool = GuidedGradCam(cmagma, layer = cmagma.magma.lm.transformer.h[0].ln_1) #cmagma.magma.image_prefix.enc.layer4[-1].conv3) #, layer = cmagma.magma.image_prefix.enc.layer4[-1].conv3) +captum_tool = InputXGradient(cmagma) +#captum_tool = IntegratedGradients(cmagma) #! set n_steps below + + +cmagma.mode='text' #hack- leave it as it is - just passes below's image embeddings thru ... + + +label_tokens = model.tokenizer.encode(targets) + +att_combined = np.zeros((12,12)) +for i in range(len(label_tokens)): + + text_prompt = f"This is a picture of {check_if_a_or_an_and_get_prefix(targets.lower())} " + if i >= 1: + text_prompt += model.tokenizer.decode(label_tokens[:i]) + + + prompt = [ + ImageInput(None, pil=final_img), + text_prompt + ] + + embeddings = cmagma.magma.preprocess_inputs(prompt) + + attribution = captum_tool.attribute( + embeddings, + target=label_tokens[i], + #n_steps = 1 #integ gradients parameters ! + ) + + att = attribution[0].abs().sum(dim = 1).cpu().detach().numpy()[:144].reshape(12,12) + + att_combined += att/att.max() + + +fig = plt.figure() +plt.imshow(att_combined) +fig.savefig('panda-explained-captum.jpg') +print('panda-explained-captum.jpg') diff --git a/atman-magma/example_chefer.py b/atman-magma/example_explain_panda_chefer.py similarity index 66% rename from atman-magma/example_chefer.py rename to atman-magma/example_explain_panda_chefer.py index 1d9d8bb..82bf009 100755 --- a/atman-magma/example_chefer.py +++ b/atman-magma/example_explain_panda_chefer.py @@ -3,6 +3,7 @@ from magma.image_input import ImageInput from atman_magma.chefer.method import CheferMethod from atman_magma.chefer.chefer_magma.magma import CheferMagma +import PIL.Image as PilImage device = 'cuda:0' model = CheferMagma.from_checkpoint( @@ -17,17 +18,18 @@ prompt =[ ## supports urls and path/to/image - ImageInput('https://www.art-prints-on-demand.com/kunst/thomas_cole/woods_hi.jpg'), + #ImageInput('https://www.art-prints-on-demand.com/kunst/thomas_cole/woods_hi.jpg'), + ImageInput('',pil=PilImage.open('openimages-panda.jpg')), 'This is a picture of a' ] embeddings = model.preprocess_inputs(prompt) relevance_maps = cm.run( - embeddings = embeddings, - target = ' cabin in the woods' + embeddings = embeddings, + target = 'Panda' ) fig = plt.figure() plt.imshow(relevance_maps[0]['relevance_map'].reshape(12,12)) -plt.show() -fig.savefig('chefer.jpg') +fig.savefig('panda-explained-chefer.jpg') +print('panda-explained-chefer.jpg') diff --git a/atman-magma/plot_panda.ipynb b/atman-magma/plot_panda.ipynb deleted file mode 100755 index ae20462..0000000 --- a/atman-magma/plot_panda.ipynb +++ /dev/null @@ -1,83 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "a349b9b9", - "metadata": {}, - "outputs": [], - "source": [ - "import pandas as pd\n", - "import matplotlib.pyplot as plt" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b3c32c8a", - "metadata": {}, - "outputs": [], - "source": [ - "import cv2\n", - "\n", - "from atman_magma.outputs import DeltaCrossEntropiesOutput\n", - "\n", - "def show_result_from_df_idx(\n", - "\n", - "):\n", - " image_filename = 'openimages-panda.jpg'\n", - " \n", - " image = DeltaCrossEntropiesOutput.from_file(filename = f'panda.json').show_image(image_token_start_idx = 0, target_token_idx= 0) **2\n", - " \n", - " #image[image<0.6]=1.0\n", - " fig, ax = plt.subplots(nrows=1, ncols=2, figsize = (15 , 6))\n", - " title = f''\n", - " fig.suptitle(title)\n", - " ax[0].imshow(cv2.cvtColor(cv2.imread(image_filename), cv2.COLOR_BGR2RGB))\n", - " \n", - " \n", - " \n", - " fads = ax[1].imshow(image, cmap = 'gray')\n", - " fig.colorbar( fads, ax=ax[1])\n", - " \n", - " import numpy as np\n", - " image = image / image.max()\n", - " image = (image*255).astype(np.uint8)\n", - " cv2.imwrite('panda-explained.jpg', image)\n", - " \n", - " " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f90498f7", - "metadata": {}, - "outputs": [], - "source": [ - "show_result_from_df_idx() " - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.16" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/atman-magma/setup.sh b/atman-magma/setup.sh deleted file mode 100755 index 2ad41a6..0000000 --- a/atman-magma/setup.sh +++ /dev/null @@ -1,2 +0,0 @@ -pip install git+https://github.com/Aleph-Alpha/magma.git --ignore-requires-python -python3 setup.py develop diff --git a/determined_eval_wrapper.py b/determined/determined_eval_wrapper.py similarity index 100% rename from determined_eval_wrapper.py rename to determined/determined_eval_wrapper.py diff --git a/metadata.json b/determined/metadata.json similarity index 100% rename from metadata.json rename to determined/metadata.json diff --git a/run_atman.yml b/determined/run_atman.yml similarity index 100% rename from run_atman.yml rename to determined/run_atman.yml diff --git a/run_atman_layers.yml b/determined/run_atman_layers.yml similarity index 100% rename from run_atman_layers.yml rename to determined/run_atman_layers.yml diff --git a/run_eval.py b/determined/run_eval.py similarity index 100% rename from run_eval.py rename to determined/run_eval.py diff --git a/figs/tab1.png b/figs/tab1.png index b0adaf0..e6738c8 100755 Binary files a/figs/tab1.png and b/figs/tab1.png differ diff --git a/magma/README.md b/magma/README.md index 18f7a66..06b5d14 100755 --- a/magma/README.md +++ b/magma/README.md @@ -1,4 +1,4 @@ -# this is a clone of the MAGMA repo -- minor adjustments +# this is a clone of the MAGMA repo -- minor adjustments -- mainly to fix dependency # MAGMA -- Multimodal Augmentation of Generative Models through Adapter-based Finetuning ## Authors diff --git a/startup-hook.sh b/startup-hook.sh index 90d307c..3199be6 100755 --- a/startup-hook.sh +++ b/startup-hook.sh @@ -1,8 +1,8 @@ -mkdir -p ~/.cache/clip - -pip install --user deepspeed==0.6.0 -pip install --user typeguard==2.11.1 -pip install --user opencv-python-headless==4.2.0.34 -pip install --user ./magma -pip install --user ./atman-open-images-eval -pip install --user ./atman-magma +pip install torch==1.13.1 torchvision --index-url https://download.pytorch.org/whl/cu117 +pip install deepspeed==0.6.0 +pip install typeguard==2.11.1 +pip install opencv-python-headless==4.2.0.34 +pip install ./magma +pip install ./atman-open-images-eval +pip install ./atman-magma +pip install gdown==4.4.0 captum