Skip to content

Commit

Permalink
adding rollout
Browse files Browse the repository at this point in the history
  • Loading branch information
bjoern committed Nov 7, 2023
1 parent ed5a6f2 commit f4b7509
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 155 deletions.
29 changes: 24 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,27 +42,33 @@ requires 1 A100

```
python example_explain_panda_chefer.py
run plot_panda.ipynb
```

## image-text/ MAGMA x Captum IxG, ...
requires 1 A100

```
python example_explain_panda_captum.py
run plot_panda.ipynb
```

## image-text/ rollout
on it
requires 1 RTX 3090

## image-text/ BLIP
on it
```
python example_explain_attention_rollout.py
```

## text/ GPT-J
```
python example_steering.py
python example_document_qa_sentence_level_explain.py
```

## image-text/ BLIP
on it



# Method and Evaluation

![steering and measuring](figs/fig2.png)
Expand All @@ -74,3 +80,16 @@ on it
![performance](figs/fig5.png)

![quantitative](figs/tab1.png)


# cite
```
@inproceedings{
deiseroth2023atman,
title={{ATMAN}: Understanding Transformer Predictions Through Memory Efficient Attention Manipulation},
author={Bj{\"o}rn Deiseroth and Mayukh Deb and Samuel Weinbach and Manuel Brack and Patrick Schramowski and Kristian Kersting},
booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
year={2023},
url={https://openreview.net/forum?id=PBpEb86bj7}
}
```
31 changes: 15 additions & 16 deletions atman-magma/atman_magma/attention_rollout/attention_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(self, model):
device = device
)
prompt = ['hello world my name is Ben and i love burgers']
attention_rollout = AttentionRolloutMagma(model = model)
embeddings = attention_rollout.model.preprocess_inputs(prompt) ## prompt is a list
Expand All @@ -30,17 +30,17 @@ def __init__(self, model):
"""
self.model = model.eval()
self.next_token_index = -1

def forward_pass(self, embeddings):
with torch.no_grad():
return self.model.lm(inputs_embeds = embeddings, output_attentions = True)

def run(self, embeddings, return_next_token_output_only: bool = False):
assert embeddings.ndim == 3, 'Expected tensor with 3 dims: (Batch, Sequence, Embedding)'
outputs = self.forward_pass(embeddings = embeddings)

return self.get_attention_rollout_from_model_outputs(outputs, return_next_token_output_only = return_next_token_output_only)

def get_attention_rollout_from_model_outputs(self, outputs, return_next_token_output_only: bool = True):
all_attentions = outputs.attentions
_attentions = [att.cpu().detach().numpy() for att in all_attentions]
Expand All @@ -59,18 +59,18 @@ def get_attention_rollout_from_model_outputs(self, outputs, return_next_token_ou
joint_attentions[0] = res_att_mat[0]
for i in np.arange(1,layers):
joint_attentions[i] = res_att_mat[i].dot(joint_attentions[i-1])

if return_next_token_output_only:
return joint_attentions[:, self.next_token_index,:]
else:
return joint_attentions

def run_on_image(self, prompt, target, manipulate_last_n_tokens: int = None):
assert isinstance(prompt[0], ImageInput), f'Expected the first prompt item to be an ImageInput but got: {type(prompt[0])}'
assert isinstance(prompt[1], str), f'Expected the second prompt item to be an str but got: {type(prompt[1])}'

prompt_embeddings = self.model.preprocess_inputs(prompt) ## prompt is a list
target_embeddings = self.model.preprocess_inputs(target)
target_embeddings = self.model.preprocess_inputs([target])
# print('len target:', target_embeddings.shape[1])
embeddings = torch.cat(
[
Expand All @@ -79,18 +79,17 @@ def run_on_image(self, prompt, target, manipulate_last_n_tokens: int = None):
],
dim =1
)

joint_attentions = self.run(embeddings, return_next_token_output_only = False)

target_token_indices = [i for i in range(prompt_embeddings.shape[1]-1, embeddings.shape[1])]

heatmap = np.zeros((12,12))

for idx in target_token_indices:
heatmap += joint_attentions[0,idx,1:145].reshape(12,12)


if manipulate_last_n_tokens is not None:
heatmap[-1,-manipulate_last_tokens:]=0. ## set explanation values of last 2 tokens to 0
return heatmap

26 changes: 0 additions & 26 deletions atman-magma/example_attention_rollout.py

This file was deleted.

108 changes: 0 additions & 108 deletions atman-magma/example_explain.py

This file was deleted.

33 changes: 33 additions & 0 deletions atman-magma/example_explain_attention_rollout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from atman_magma.magma import Magma
from atman_magma.attention_rollout import AttentionRolloutMagma
import matplotlib.pyplot as plt

from magma.image_input import ImageInput
import PIL.Image as PilImage


print('loading model...')
model = Magma.from_checkpoint(
checkpoint_path = "./mp_rank_00_model_states.pt",
device = 'cuda:0'
)
ar = AttentionRolloutMagma(model = model)



prompt =[
## supports urls and path/to/image
#ImageInput('https://www.art-prints-on-demand.com/kunst/thomas_cole/woods_hi.jpg'),
ImageInput('',pil=PilImage.open('openimages-panda.jpg')),
'This is a picture of a'
]

relevance_maps = ar.run_on_image(
prompt=prompt,
target = 'Panda', # note rollout per se does not have a target
)

fig = plt.figure()
plt.imshow(relevance_maps.reshape(12,12))
fig.savefig('panda-explained-rollout.jpg')
print('panda-explained-rollout.jpg')

0 comments on commit f4b7509

Please sign in to comment.