
llama: (proposal) propagating the results of graph_compute to the user interface #9525

Merged: 7 commits into ggerganov:master from the ggml_status_to_user branch on Nov 13, 2024

Conversation

Xarbirus
Contributor

In PR #9434 it was proposed to use an enum instead of an error code. This PR proposes the first step towards that idea: passing the results of ggml_backend_sched_graph_compute to llama_decode/llama_encode so that they can be processed on the user side.
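A minimal sketch of the idea, assuming the current ggml_status enum and ggml-backend API; the helper names and the mapping to integer return codes are illustrative, not necessarily the code merged in this PR:

```cpp
#include "ggml-backend.h"

// Sketch only: the ggml_status returned by the scheduler is handed back to the
// caller instead of being discarded behind an assert.
static enum ggml_status graph_compute_sketch(ggml_backend_sched_t sched, struct ggml_cgraph * gf) {
    return ggml_backend_sched_graph_compute(sched, gf);
}

// llama_decode/llama_encode can then translate the status into their public
// integer return value, for example (illustrative mapping):
static int map_status_sketch(enum ggml_status status) {
    switch (status) {
        case GGML_STATUS_SUCCESS: return  0;
        case GGML_STATUS_ABORTED: return  2; // stopped via the abort callback
        default:                  return -3; // allocation or compute failure
    }
}
```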

@Xarbirus
Contributor Author

@conradev what do you think about this PR as a first step towards this?

@slaren
Collaborator

slaren commented Sep 17, 2024

Will a failure during compute leave the KV cache in a consistent state? Otherwise this will not be very useful.

@Xarbirus
Contributor Author

@slaren As far as I understand, currently, in case of a failure or abort, the kv_cache remains in the state configured by llama_kv_cache_find_slot; processing the return code had no effect on this. In the next commit I will add code that handles the result of the compute and returns the kv_cache to the state it was in before llama_kv_cache_find_slot.

@Xarbirus force-pushed the ggml_status_to_user branch 2 times, most recently from 5535683 to 059e78c on October 14, 2024 at 13:37
@Xarbirus
Contributor Author

@slaren @ggerganov Sorry to bother you, but I would like to know whether this pull request can be accepted.

@slaren
Collaborator

slaren commented Oct 14, 2024

I think that making a copy of the entire state of the KV cache is too expensive to do on every evaluation. There must be a more lightweight way to do this, but that may also require more changes to the way the KV state is handled. At the very least, what llama_kv_cache_find_slot does should be made absolutely clear, so that it is possible to tell whether the operations that undo it are correct. I don't know this implementation well enough to say anything more specific.

@ggerganov
Owner

ggerganov commented Oct 14, 2024

I agree that copying the entire KV state should be avoided since it can incur significant overhead. Right now, the KV state is partially managed by the user code, which needs to explicitly keep track of the sequence lengths. This makes it easy for the user code to discard any result from the last batch processing, by simply submitting the next batch with the appropriate sequence positions (as if the last failed batch was never submitted). However, if in the future we want to delegate the logic for sequence positions internally to libllama, then we will need some sort of undo mechanism to handle failures.

Overall, I'm not sure how to proceed either. On one hand, I want to reimplement the KV state logic in a way that will allow us to support different KV cache implementations, but that can take a while to achieve. As for the error propagation, it is easier for me to think about it case by case. What errors do you encounter that you would like to be able to recover from? For example, the examples already demonstrate how to handle the error when the KV cache is full. In my experience, I haven't encountered other types of errors that need to be handled at runtime.
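For reference, a hypothetical caller-side sketch of the error handling mentioned above, based on the documented llama_decode return values (0 = success, 1 = could not find a KV slot for the batch, < 0 = error); the recovery policy itself is only an example:

```cpp
#include "llama.h"

// Hedged example: the retry/recovery policy below is illustrative.
static bool try_decode(llama_context * ctx, llama_batch batch) {
    const int ret = llama_decode(ctx, batch);
    if (ret == 0) {
        return true;  // batch processed, logits/embeddings are valid
    }
    if (ret == 1) {
        // KV cache is full: the caller can free space (e.g. remove or shift
        // old tokens) and submit the batch again
        return false;
    }
    // ret < 0: error, discard any results from this batch
    return false;
}
```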

@slaren
Collaborator

slaren commented Oct 14, 2024

> This makes it easy for the user code to discard any result from the last batch processing, by simply submitting the next batch with the appropriate sequence positions (as if the last failed batch was never submitted)

What about the changes that llama_kv_cache_find_slot makes? Shouldn't they be reverted?

> What errors do you encounter that you would like to be able to recover from?

I think that at the moment the only way the compute can fail is if the user cancels it with the abort callback. However, that's only because the backends currently just crash the application with an assert when something goes wrong; at some point that should be addressed by returning an error instead.

@ggerganov
Owner

> What about the changes that llama_kv_cache_find_slot makes? Shouldn't they be reverted?

Yes, actually they should be. I was incorrectly thinking only about the failure to find a KV slot. For non-recurrent models, these changes can be reverted with an appropriate set of llama_kv_cache_seq_rm calls. Not sure about recurrent models though; that might be more involved.
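To illustrate the llama_kv_cache_seq_rm approach for non-recurrent models, here is a hedged sketch; `first_new_pos` is a hypothetical value the caller would track (the sequence length before the failed decode):

```cpp
#include "llama.h"

// Sketch only: after a failed batch, remove every position the batch would
// have appended, for each affected sequence.
static void revert_failed_batch(llama_context * ctx, llama_seq_id seq_id, llama_pos first_new_pos) {
    // p1 = -1 means "from p0 to the end of the sequence"
    llama_kv_cache_seq_rm(ctx, seq_id, first_new_pos, -1);
}
```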

@Xarbirus
Contributor Author

@ggerganov @slaren Thanks for sharing your thoughts.

From my point of view the situation is as follows: currently, in case of a non-critical failure in decoding/encoding (i.e. if ggml_abort was not called), kv_cache.cells remains in a modified state, and users need to explicitly call llama_kv_cache_seq_rm to put the cache back in order. So simply returning the error codes from ggml (I mean the changes that concern the switch statements) does not change the logic of llama.cpp, which means that part of the change is safe. In the future it may also allow replacing some of the ggml_abort calls with returned error codes.

As for reverting the changes in kv_cache.cells, I agree that a simple copy is too expensive. I see two options:

  1. Do not change anything (remove kv_cache_state_holder from this review) and continue to rely on users reverting everything themselves.
  2. Add the cache-revert logic to libllama (a breaking change for those who already revert the cache explicitly after errors, since they will need to remove the llama_kv_cache_seq_rm call from their code). Since copying is not an option, and a complete cache rework looks far too large for the scope of this PR, my idea is to return a pair of indices from llama_kv_cache_find_slot marking the beginning and end of the slot (inclusive); if the computation fails, the cells within these boundaries are cleared (see the sketch below). This applies to non-recurrent models. For recurrent models it still seems easier to copy the cells (there shouldn't be many of them), although the cache handling for recurrent models probably deserves a more thorough redesign.
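A hypothetical sketch of the pair-of-indices idea from option 2 above; the type and function names are illustrative, not the code merged in this PR:

```cpp
#include <cstdint>
#include <set>
#include <vector>

// Simplified stand-in for a KV cell (the real cache stores more fields).
struct kv_cell_sketch {
    int32_t           pos = -1;
    std::set<int32_t> seq_id;
};

// Boundaries of the slot occupied by one ubatch, as they would be returned by
// a modified llama_kv_cache_find_slot in this sketch.
struct kv_slot_range {
    uint32_t first = 0; // first cell taken by the ubatch
    uint32_t last  = 0; // last cell taken by the ubatch (inclusive)
};

// On compute failure, clear exactly the cells the ubatch occupied.
static void clear_slot_on_failure(std::vector<kv_cell_sketch> & cells, uint32_t & used, const kv_slot_range & range) {
    for (uint32_t i = range.first; i <= range.last; ++i) {
        cells[i].pos = -1;
        cells[i].seq_id.clear();
    }
    used -= range.last - range.first + 1;
}
```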

In the end, as part of this PR I would suggest:

  1. adding propagation of the execution status from llama_graph_compute;
  2. removing kv_cache_state_holder.

In a separate PR I would then suggest adding rollback of the llama_kv_cache_find_slot changes using the slot start and end indices (this should work for non-recurrent models).

WDYT?

@Xarbirus
Contributor Author

So I added both solutions:

  • Commit 7c083f5 removes the cache revert logic and leaves only the error codes returned by llama_graph_compute
  • Commit 0c05c60 adds the cache revert logic

@slaren
Collaborator

slaren commented Oct 21, 2024

The cache revert logic should apply to the entire batch, not just the latest ubatch. Otherwise, it is possible that only a fraction of the changes to the KV cache will be reverted, leaving the KV cache in an inconsistent state.

@Xarbirus
Contributor Author

@slaren I updated the logic of llama_kv_slot_restorer and also added a small fix for recurrent models to llama_kv_cache_find_slot.
Note that for recurrent models, reverting the changes made by llama_kv_cache_find_slot clears the state completely (recurrent models like Mamba or RWKV cannot have their state partially erased).
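A rough sketch of a batch-wide restorer along these lines; the structure and names are illustrative and not necessarily the llama_kv_slot_restorer merged here:

```cpp
#include <cstdint>
#include <functional>
#include <vector>

// Records the slot range of every ubatch; if any graph computation fails, all
// of them are reverted so the cache is not left half-updated.
struct kv_slot_restorer_sketch {
    struct range { uint32_t first, last; };

    std::vector<range> ranges;  // one entry per processed ubatch
    bool full_restore = false;  // recurrent models: the whole state must be
                                // cleared/restored instead of single ranges

    void save(uint32_t first, uint32_t last) {
        ranges.push_back({first, last});
    }

    // called when compute fails for any ubatch of the batch
    void restore(const std::function<void(uint32_t, uint32_t)> & clear_range) {
        for (const auto & r : ranges) {
            clear_range(r.first, r.last);
        }
        ranges.clear();
    }
};
```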

@slaren
Collaborator

slaren commented Oct 31, 2024

@Xarbirus thanks. I am not confident that I understand the logic of llama_kv_cache_find_slot well enough to review this, so I will leave the review of the implementation to @ggerganov or someone else.

@ggerganov
Owner

@Xarbirus Sorry for the slow progress here. I'll be revisiting the KV cache implementation in the next few days and will try to figure out how to refactor and improve it. I will also try to resolve the error propagation.

@compilade
Collaborator

compilade commented Nov 1, 2024

For hybrid models like Jamba (#7531), I also ran into the need to keep both caches (recurrent states and self-attention KV) in a consistent state when one of them fails to allocate a slot.

What I noticed is that checking if the allocation will succeed can be done fully ahead of modifying the cache.

There is no need to revert what can be avoided in the first place.

"Checking first" does not handle reverting the full batch, only the latest ubatch, but at least the cache is in a consistent state at ubatch boundaries, so llama_kv_cache_seq_rm could be safely called on failure.1

Not necessarily the way to go, but another possible approach.

Footnotes

  1. To avoid potentially clearing recurrent states on failure, the slots for the entire batch could perhaps be checked/planned in advance. I'll keep this in mind for the KV + recurrent cache redesign; depending on how it's implemented, it could even allow dynamically merging a new batch into an existing "queue" of ubatches and gracefully handling failures for the new batch without affecting the current one.
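To make the "check first" idea above concrete, here is a hypothetical split of the slot search into a read-only check and a mutating apply step; both function names and signatures are made up for this sketch:

```cpp
struct kv_cache_sketch;  // stands in for the real KV cache type
struct ubatch_sketch;    // stands in for the real ubatch type

// Pure check: returns true if a slot of the required size is available,
// without modifying the cache.
bool kv_cache_can_fit(const kv_cache_sketch & cache, const ubatch_sketch & ubatch);

// Mutating step: only called after the check above has succeeded, so there is
// nothing to revert when the check fails.
void kv_cache_apply_slot(kv_cache_sketch & cache, const ubatch_sketch & ubatch);

// pseudo-flow inside the decode loop:
//   if (!kv_cache_can_fit(cache, ubatch)) {
//       return 1; // report "no slot" without having touched the cache
//   }
//   kv_cache_apply_slot(cache, ubatch);
```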

@Xarbirus
Contributor Author

Xarbirus commented Nov 9, 2024

@compilade
Yes, I agree that we can split the slot search and its initialization in llama_kv_cache_find_slot. Maybe it even makes sense to split this function into two, but in that case I would wait for the refactoring from @ggerganov.

But, as you correctly noted, while this leaves the cache in a consistent state, it is not the state the caller of this function expects (since the cache can be "half-full"). Besides, in case of an error in llama_graph_compute we would still have to roll back the cache to its initial state.

@ggerganov how is the refactoring going? Should I wait for it, or do you have any ideas on how to improve the code in this PR?

@Xarbirus force-pushed the ggml_status_to_user branch from bbf27cc to ee599f9 on November 9, 2024 at 18:39
@ggerganov
Owner

@Xarbirus Initially I was planning to start working on this when I wrote the last comment, but I got sidetracked with improving the Metal backend. I expect to continue working on the Metal backend this week, so I will probably get to the KV cache refactoring after that.

The latest version of this PR is OK to merge. Just add information to the comments in llama.h about the new behaviour of this function.

@Xarbirus
Contributor Author

@ggerganov Done, please check.

@ggerganov merged commit fb4a0ec into ggerganov:master on Nov 13, 2024
52 checks passed
arthw pushed a commit to arthw/llama.cpp that referenced this pull request Nov 15, 2024
* llama: propagating the results of `graph_compute` to the user interface

* llama: reverting kv_cache in case of failed compute

* llama: `llama_kv_cache_state` was removed, only the result of `llama_graph_compute` is returned

* llama: restore a kv_cache in case of failed computation

* llama: correct reverting of the entire batch.
also updates `llama_kv_cache_find_slot`, will correctly count the number of `used` cells for recurrent models

* llama: updated comments

* llama : add comments about KV cache state after error

---------

Co-authored-by: Georgi Gerganov <[email protected]>
arthw pushed commits with the identical commit message to arthw/llama.cpp, referencing this pull request, again on Nov 17, 2024 and Nov 18, 2024.
@Xarbirus deleted the ggml_status_to_user branch on November 23, 2024 at 16:52