Commit: more cleanup
bjoern committed Nov 6, 2023
1 parent 7828983 commit 0555893

Showing 17 changed files with 147 additions and 179 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -1,2 +1,5 @@
*.pt
*.pth
*pyc
*build*
*.egg*
41 changes: 35 additions & 6 deletions README.md
@@ -9,30 +9,59 @@ As depicted in following examples, one is able to highlight various discriminati

[Paper Link](https://arxiv.org/abs/2301.08110)

## roadmap
- continue to clean up the repo, in particular remove the Explainer class and other overhead
- more examples
- Hugging Face integration?

## prelim
This repo includes the XAI methods AtMan and Chefer, and a Captum interface for IG, GradCAM, etc., for the language model GPT-J and the vision-language models [MAGMA](https://github.com/Aleph-Alpha/magma) and [BLIP](https://colab.research.google.com/github/salesforce/BLIP).
This repo includes the XAI methods AtMan and Chefer, and a Captum interface for IG, GradCAM, etc., for the language model GPT-J and the vision-language models [MAGMA](https://github.com/Aleph-Alpha/magma) and [BLIP](https://colab.research.google.com/github/salesforce/BLIP). (Big props to Mayukh Deb.)

To install all required dependencies, run the following command, e.g., in a conda environment with Python 3.8:
```
bash startup-hook.sh
```
Note: additional model checkpoints are downloaded on first execution. CLIP sometimes fails verification on the first run; running again usually fixes it.
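
As a quick sanity check after installation, the imports used throughout the example scripts should resolve (a minimal sketch; the module names are the ones imported by the examples in this repo):
```
# smoke test: these modules are imported by the example scripts below
from atman_magma.magma import Magma
from atman_magma.explainer import Explainer
from atman_magma.logit_parsing import get_delta_cross_entropies
print('imports ok')
```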

# examples
## image-text/ MAGMA
TODO: examples for different methods (script + image)
# examples with MAGMA
```
cd atman-magma
python example_explain_panda.py
```
## image-text/ MAGMA x AtMan
requires 1 RTX 3090

```
python example_explain_panda_atman.py
```

## image-text/ MAGMA x Chefer
requires 1 A100

```
python example_explain_panda_chefer.py
run plot_panda.ipynb
```

## image-text/ MAGMA x Captum IxG, ...
requires 1 A100

```
python example_explain_panda_captum.py
run plot_panda.ipynb
```
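
The Captum example defaults to InputXGradient; as the new example_explain_panda_captum.py further down shows, IntegratedGradients and GuidedGradCam can be swapped in via the commented-out constructors (Integrated Gradients additionally takes an n_steps argument in attribute()).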

## image-text/ rollout
work in progress

## image-text/ BLIP
work in progress

## text/ GPT-J
work in progress


# more to read
# Method and Evaluation

![steering and measuring](figs/fig2.png)

2 changes: 1 addition & 1 deletion atman-magma/example_attention_rollout.py
@@ -1,5 +1,5 @@
from atman_magma.magma import Magma
from atman.attention_rollout import AttentionRollout
from atman_magma.attention_rollout import AttentionRolloutMagma

print('loading model...')
model = Magma.from_checkpoint(
56 changes: 0 additions & 56 deletions atman-magma/example_big_loop.py

This file was deleted.

@@ -2,6 +2,10 @@
from atman_magma.explainer import Explainer
from atman_magma.utils import split_str_into_tokens
from atman_magma.logit_parsing import get_delta_cross_entropies
import matplotlib.pyplot as plt
import cv2
from atman_magma.outputs import DeltaCrossEntropiesOutput
import numpy as np


print('loading model...')
@@ -11,13 +15,10 @@
device = device
)


'''
Image example
'''
from magma.image_input import ImageInput
import PIL.Image as PilImage


ex = Explainer(
model = model,
device = device,
@@ -26,7 +27,6 @@
conceptual_suppression_threshold = 0.75
)


prompt =[
## supports urls and path/to/image
ImageInput('',pil=PilImage.open('openimages-panda.jpg')),
@@ -36,19 +36,10 @@
## returns a tensor of shape: (1, 149, 4096)
embeddings = model.preprocess_inputs(prompt.copy())

## returns a list of length embeddings.shape[0] (batch size)
# output = model.generate(
# embeddings = embeddings,
# max_steps = 5,
# temperature = 0.001,
# top_k = 1,
# top_p = 0.0,
# )
# completion = output[0]

label ='Panda'
logit_outputs = ex.collect_logits_by_manipulating_attention(
prompt = prompt.copy(),
target = 'Panda',
target = label,
max_batch_size=1,
# prompt_explain_indices=[i for i in range(10)]
)
@@ -57,4 +48,20 @@
output = logit_outputs
)

results.save('output.json')
image_filename = 'openimages-panda.jpg'

label_tokens = model.tokenizer.encode(label)

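# combine per-target-token heatmaps: square each 12x12 relevance map and accumulate over the label's tokens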
image = np.zeros((12,12))
for i in range(len(label_tokens)):
image += results.show_image(image_token_start_idx = 0, target_token_idx= i) **2

# image[image<0.6]=1.0
fig, ax = plt.subplots(nrows=1, ncols=2, figsize = (15 , 6))
title = f''
fig.suptitle(title)
ax[0].imshow(cv2.cvtColor(cv2.imread(image_filename), cv2.COLOR_BGR2RGB))
ax[1].imshow(image)

fig.savefig('panda-explained-atman.jpg')
print('panda-explained-atman.jpg')
68 changes: 68 additions & 0 deletions atman-magma/example_explain_panda_captum.py
@@ -0,0 +1,68 @@
from PIL import Image
import matplotlib.pyplot as plt

from atman_magma.captum_helper import (
CaptumMagma,
)
from multimodal_explain_eval.utils import check_if_a_or_an_and_get_prefix
import numpy as np
from atman_magma.magma import Magma
from magma.image_input import ImageInput
from PIL import Image

from captum.attr import IntegratedGradients, InputXGradient, GuidedGradCam
from captum.attr import LayerGradCam


targets = 'Panda'
final_img = Image.open('openimages-panda.jpg')


print('loading model...')
model = Magma.from_checkpoint(
checkpoint_path = "./mp_rank_00_model_states.pt",
device = 'cuda:0'
)

cmagma = CaptumMagma(magma = model)
# captum_tool = IntegratedGradients(cmagma)
# captum_tool = GuidedGradCam(cmagma, layer = cmagma.magma.lm.transformer.h[0].ln_1) #cmagma.magma.image_prefix.enc.layer4[-1].conv3) #, layer = cmagma.magma.image_prefix.enc.layer4[-1].conv3)
captum_tool = InputXGradient(cmagma)
#captum_tool = IntegratedGradients(cmagma) #! set n_steps below


cmagma.mode = 'text'  # hack: leave as is - this just passes the precomputed image embeddings below straight through


label_tokens = model.tokenizer.encode(targets)

att_combined = np.zeros((12,12))
for i in range(len(label_tokens)):

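# explain one label token at a time: tokens already attributed are decoded back into the prompt, so token i is conditioned on the tokens before it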
text_prompt = f"This is a picture of {check_if_a_or_an_and_get_prefix(targets.lower())} "
if i >= 1:
text_prompt += model.tokenizer.decode(label_tokens[:i])


prompt = [
ImageInput(None, pil=final_img),
text_prompt
]

embeddings = cmagma.magma.preprocess_inputs(prompt)

attribution = captum_tool.attribute(
embeddings,
target=label_tokens[i],
#n_steps = 1 #integ gradients parameters !
)

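# reduce to a heatmap: sum |attribution| over the embedding dim, keep the 144 image-token positions, reshape to the 12x12 patch grid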
att = attribution[0].abs().sum(dim = 1).cpu().detach().numpy()[:144].reshape(12,12)

att_combined += att/att.max()


fig = plt.figure()
plt.imshow(att_combined)
fig.savefig('panda-explained-captum.jpg')
print('panda-explained-captum.jpg')
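
Note that each per-token map is max-normalized (att / att.max()) before accumulation, so every label token contributes on a comparable scale to the combined heatmap.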
@@ -3,6 +3,7 @@
from magma.image_input import ImageInput
from atman_magma.chefer.method import CheferMethod
from atman_magma.chefer.chefer_magma.magma import CheferMagma
import PIL.Image as PilImage

device = 'cuda:0'
model = CheferMagma.from_checkpoint(
@@ -17,17 +18,18 @@

prompt =[
## supports urls and path/to/image
ImageInput('https://www.art-prints-on-demand.com/kunst/thomas_cole/woods_hi.jpg'),
#ImageInput('https://www.art-prints-on-demand.com/kunst/thomas_cole/woods_hi.jpg'),
ImageInput('',pil=PilImage.open('openimages-panda.jpg')),
'This is a picture of a'
]
embeddings = model.preprocess_inputs(prompt)

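# relevance_maps[0]['relevance_map'] is plotted below, reshaped to the 12x12 image-patch grid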
relevance_maps = cm.run(
embeddings = embeddings,
target = ' cabin in the woods'
embeddings = embeddings,
target = 'Panda'
)

fig = plt.figure()
plt.imshow(relevance_maps[0]['relevance_map'].reshape(12,12))
plt.show()
fig.savefig('chefer.jpg')
fig.savefig('panda-explained-chefer.jpg')
print('panda-explained-chefer.jpg')
83 changes: 0 additions & 83 deletions atman-magma/plot_panda.ipynb

This file was deleted.

