[Desiderata] Captum-like implementation for Inseq compatibility #1
Hey @gsarti, I'm glad that you like our paper! You do great work at While LRP substantially outperforms other methods, it has an initial 'set-up' cost, i.e. there is currently no implementation in PyTorch that is able to automatically apply the rules to all operations in a PyTorch graph. For instance, in LRP we must apply the epsilon rule on every summation operation. This means that if we have a line of code such as I'm not aware of a way to do this kind of code manipulation/graph manipulation on the fly. I just found this tutorial on torch.fx. Maybe this is the solution? As a consequence, I'm implementing LRP right now in the style of zennit, but I'm optimistic that we can somehow integrate it into Best greetings, and thank you again (:
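For readers unfamiliar with the epsilon rule mentioned above, here is a minimal sketch of how it redistributes relevance through a single summation (a linear layer). This is not the AttnLRP code and does not use zennit or torch.fx; the function name `epsilon_lrp_linear` and the toy numbers are assumptions made purely for illustration, and the sketch is written in plain Python so it runs without PyTorch.

```python
def epsilon_lrp_linear(a, W, b, R_out, eps=1e-6):
    """Epsilon-rule sketch for z_j = sum_i a[i] * W[i][j] + b[j].

    a:     input activations (length n_in)
    W:     weights, indexed W[i][j] (n_in x n_out)
    b:     biases (length n_out)
    R_out: relevance arriving at each output j (length n_out)
    Returns the relevance assigned to each input i.
    """
    n_in, n_out = len(a), len(b)
    # Forward pass: the summation whose relevance we redistribute.
    z = [sum(a[i] * W[i][j] for i in range(n_in)) + b[j] for j in range(n_out)]
    # Stabilizer: move z_j away from zero by eps while keeping its sign.
    z_stab = [zj + (eps if zj >= 0 else -eps) for zj in z]
    # Input i receives the fraction a_i * w_ij / z_j of output j's relevance.
    return [
        sum(a[i] * W[i][j] / z_stab[j] * R_out[j] for j in range(n_out))
        for i in range(n_in)
    ]


# Hypothetical toy values, chosen only to make the arithmetic easy to follow.
a = [1.0, 2.0]
W = [[0.5, -1.0], [0.25, 1.0]]
b = [0.0, 0.0]
R_out = [1.0, 1.0]
R_in = epsilon_lrp_linear(a, W, b, R_out)
```

With zero biases, the input relevances sum to (almost exactly) the output relevance, which is the conservation property the rule is designed to preserve. The 'set-up' cost described above comes from having to apply such a rule at every summation in the graph, not just at layers that happen to be `nn.Linear` modules.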
Thanks for your prompt reply @rachtibat! I see the issue with setup costs, thanks for clarifying! I had an in-depth look to Would the
Hey @gsarti, awesome that you already had so much experience with I'm not quite sure what you mean by multi-token generation, but I'll try to give you an idea of what is possible if someone wants to explain several tokens at once:
I hope this explains it. If anything is unclear, feel free to ask again (:
Thanks for the response @rachtibat! To clarify, the background to my question was that libraries like Captum typically provide an interface to streamline the attribution of a single forward output (the first bullet point you describe). However, there is no simple abstraction to automate the "one attribution per generation step" process you describe in the third bullet point (although in the case of Captum, they actually added something akin to this in v0.7). The main reason of The 2nd approach you mention (the one proposing a "superposition" of 3 attributions) looks very interesting, and I think it's the first time I've seen this idea! But I have a doubt: this would effectively mean taking the output logits of previous tokens (e.g. 2 and 5 in your example) when computing the forward pass for token N-4 and using them to propagate relevance back into the model. Don't you think this is a bit unnatural for extracting rationales, given that only the last token is used to compute the prediction at every generation step? I'm not sure what information the preceding embeddings would provide in this context. Curious to hear your thoughts!
Hey, afaik transformers are trained with a next-token prediction objective at every output position. If you look at the huggingface implementation of Llama 2, for instance, you see that the labels for CrossEntropy are the inputs shifted by one. So the model actually predicts at each output position and not just at the last token. Because of the causal masking in attention heads, each output position N can only see the prior N-1 input tokens and makes an independent prediction. I already tried it: computing the attribution for each output token separately and adding them up gives the same result as computing a superimposed heatmap at once. This is also due to the fact that LRP is an additive explanatory model, i.e. the attribution can be disentangled into several independent relevance flows. We described this phenomenon in this paper: So, I think this might be a feature only present in additive explanatory models.
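The additivity argument above can be illustrated on a toy example. The sketch below is an assumption-laden stand-in, not AttnLRP and not a transformer: it uses a single linear layer whose weight matrix is triangular (so output position j only depends on inputs i <= j, mimicking causal masking) and a hypothetical helper `lrp_eps` implementing the epsilon rule in plain Python. It compares per-position attributions, summed afterwards, against a single "superimposed" attribution.

```python
def lrp_eps(a, W, R_out, eps=1e-9):
    """Epsilon-rule relevance for z_j = sum_i a[i] * W[i][j] (toy helper)."""
    n_in, n_out = len(W), len(W[0])
    z = [sum(a[i] * W[i][j] for i in range(n_in)) for j in range(n_out)]
    z = [zj + (eps if zj >= 0 else -eps) for zj in z]
    return [
        sum(a[i] * W[i][j] / z[j] * R_out[j] for j in range(n_out))
        for i in range(n_in)
    ]


# Toy "causal" layer: W[i][j] is non-zero only for i <= j, so output
# position j cannot see inputs that come after it.
W = [[1.0, 0.5, 0.2],
     [0.0, 1.0, 0.3],
     [0.0, 0.0, 1.0]]
a = [1.0, 2.0, 3.0]

# Attribute output positions 0 and 1 separately (one-hot start relevance)...
R_pos0 = lrp_eps(a, W, [1.0, 0.0, 0.0])
R_pos1 = lrp_eps(a, W, [0.0, 1.0, 0.0])
summed = [r0 + r1 for r0, r1 in zip(R_pos0, R_pos1)]

# ...and both at once, by starting relevance at both positions together.
superimposed = lrp_eps(a, W, [1.0, 1.0, 0.0])
```

Here `summed` and `superimposed` agree up to floating-point error, and `R_pos0` assigns zero relevance to inputs 1 and 2 because output position 0 never saw them. The equality holds because, for fixed activations, the backward relevance pass is linear in the starting relevance; that a full transformer behaves the same way is the claim made in the comment above, which this toy layer does not by itself prove.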
This is very interesting, you're right! I was thinking of inference, but it is true that at training time the model does indeed predict a token at every position. The fact that it results in a simple sum of independent relevance flows is definitely an upside of additive models. Looking forward to testing it out! :)
Hi @rachtibat,
Great job on AttnLRP, your LRP adaptation seems very promising to attribute the behavior of Transformer-based LMs!
Provided you guys are still working on the codebase, I was wondering whether it would be possible to have an implementation that is interoperable with the Captum LRP class. This would allow us to include the method in the inseq library (reference: inseq-team/inseq#122), enabling out-of-the-box support for: 🤗 transformers.
inseq has already been used in a number of publications since its release last year, and having an accessible implementation of AttnLRP there would undoubtedly help to democratize access to your method.
From an implementation perspective, I'm not an LRP connoisseur, but my understanding is that to ensure Captum compatibility it would be enough to specify your custom propagation rules matching the base class provided here.
Let me know your thoughts, and congrats again on the excellent work!