
Add CUDA option to use the max release threshold for the default memory pool #5429

Draft · wants to merge 1 commit into base: master
Conversation

@YavorGIvanov commented Feb 9, 2024

I was using the GGML library to run multiple sequential inference jobs one after the other with a large batch size. The goal was to maximize throughput, so the batch size was tuned to use essentially all of the GPU memory on the device I was running on (in my case an A10 with 24 GB). This brought the KV cache allocation close to 16 GB. I made this allocation at the beginning of every job, and I noticed 300-400 ms spikes when allocating the KV cache backend buffer for every job except the first one (the first took 10-20 ms). I tried multiple things without success and then asked about it on the NVIDIA developer forum: https://forums.developer.nvidia.com/t/16gb-cudamalloc-on-a10-24gb-takes-300-400ms-after-previous-cudafree/281839.

I verified that setting the maximum release threshold on the default memory pool solves the problem, since it allows the pool to keep the memory allocated by previous jobs instead of releasing it back to the OS. This also sped up the other allocations I was doing, as they were served directly from the pool.
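The relevant pool setup looks roughly like this (a minimal sketch against the standard CUDA runtime API; the function name is mine):

```cpp
#include <cuda_runtime.h>
#include <cstdint>

// Keep freed memory cached in the device's default memory pool instead of
// returning it to the OS. Function name is mine; error handling is trimmed.
static cudaError_t set_max_release_threshold(int device) {
    cudaMemPool_t pool;
    cudaError_t err = cudaDeviceGetDefaultMemPool(&pool, device);
    if (err != cudaSuccess) {
        return err;
    }
    uint64_t threshold = UINT64_MAX; // never trim the pool back down
    return cudaMemPoolSetAttribute(pool, cudaMemPoolAttrReleaseThreshold, &threshold);
}
```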

I think this is a sensible default when you are sure that the process running LLAMA is the main user of the device. I added an option to the CMakeLists and the CUDA backend to enable it. I am not enabling it by default, since I assume it could cause memory problems in use cases where multiple processes try to allocate a lot of GPU memory. I observed that the Stream-Ordered CUDA memory allocation API is required for this to actually make a difference, which is why I changed cudaMalloc and cudaFree to cudaMallocAsync and cudaFreeAsync plus a stream synchronization.
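In sketch form, the allocation path becomes something like the following (illustrative names, error handling omitted):

```cpp
#include <cuda_runtime.h>
#include <cstddef>

// Sketch of the alloc/free pattern after the change; "stream" stands for the
// backend's CUDA stream and the function names are illustrative.
static void * alloc_device_buffer(size_t size, cudaStream_t stream) {
    void * ptr = nullptr;
    cudaMallocAsync(&ptr, size, stream);   // served from the default memory pool
    cudaStreamSynchronize(stream);         // buffer is usable after this point
    return ptr;
}

static void free_device_buffer(void * ptr, cudaStream_t stream) {
    cudaFreeAsync(ptr, stream);            // memory goes back to the pool, not the OS
    cudaStreamSynchronize(stream);
}
```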

My initial thought was to add a "release_threshold_mb" parameter to the ggml_backend_cuda_init(..) function, but that would require changing all the GGML examples, and I also saw that this is not currently done for other options. It probably makes sense in the future to give ggml_backend_cuda_init(..) a second argument with a struct containing all the backend-specific parameters/configuration. I can move the option, or change it to a MB value instead of ON/OFF, if you prefer.

@slaren (Collaborator) commented Feb 9, 2024

The stream-ordered allocator is not available on every device and cannot be used unconditionally. We had a lot of issues with it when an attempt was made to convert the CUDA pool to use it. Additionally, the global device streams cannot be used in a buffer; these streams are effectively private to the backend instance and will eventually be moved entirely there (the only reason they are still global is to simplify the transition to ggml-backend).

I suspect that this is not the right place to fix this problem. Memory allocation should not be expected to be fast, and the best approach here may be to reuse the buffers instead.
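A minimal sketch of the reuse approach (hypothetical helper, not part of ggml):

```cpp
#include <cuda_runtime.h>
#include <cstddef>

// Hypothetical helper, not part of ggml: keep the largest device buffer
// allocated so far and only reallocate when a job needs more space.
static void * cached_ptr  = nullptr;
static size_t cached_size = 0;

static void * get_reused_buffer(size_t required) {
    if (required > cached_size) {
        if (cached_ptr != nullptr) {
            cudaFree(cached_ptr);
        }
        if (cudaMalloc(&cached_ptr, required) != cudaSuccess) {
            cached_ptr  = nullptr;
            cached_size = 0;
            return nullptr;
        }
        cached_size = required;   // the allocation cost is paid only when growing
    }
    return cached_ptr;            // otherwise the existing buffer is reused as-is
}
```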

@YavorGIvanov (Author) commented Feb 9, 2024

> The stream-ordered allocator is not available on every device and cannot be used unconditionally. We had a lot of issues with it when an attempt was made to convert the CUDA pool to use it.

OK, got it. Is that the case for older devices, or mainly for HIP and ROCm?

> Additionally, the global device streams cannot be used in a buffer; these streams are effectively private to the backend instance and will eventually be moved entirely there (the only reason they are still global is to simplify the transition to ggml-backend).

This also makes sense. I think there are use cases where making the alloc/free asynchronous on a stream would yield some performance, but I guess it doesn't fit the ggml-backend/buffer design.

> I suspect that this is not the right place to fix this problem. Memory allocation should not be expected to be fast, and the best approach here may be to reuse the buffers instead.

However, I do not agree with this. Having 300-400 ms spikes when allocating through the GGML backend, which can clearly be mitigated with the memory pool's max release threshold, is a big deal. The code can of course be made a bit more complex so that the GGML backend buffer is reused, eliminating the need for this setting, but I feel we are not taking full advantage of the CUDA backend's capabilities that way.

I am writing a custom CUDA backend, so if you confirm that there is no way this fits the ggml-backend implementation even as an option, I will close the PR. I initially opened it because I wanted feedback on whether adding such an option is possible and was not sure of the proper way to do it, but I felt it would be very useful for other users of the library. However, I can just apply a similar change in my custom CUDA backend, as I do not want to increase code complexity just to fit the general ggml-backend API.

@slaren (Collaborator) commented Feb 9, 2024

It is not designed to be used in this way, but I think it would be possible to create a different ggml-backend buffer type that behaves asynchronously. You could pass a ggml_backend when creating this buffer type, and buffers allocated from it would use the stream of that ggml_backend instance to run all their operations.
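Roughly along these lines (hypothetical types, not the actual ggml-backend interface):

```cpp
#include <cuda_runtime.h>
#include <cstddef>

// Hypothetical types, not the actual ggml-backend interface: the buffer type
// carries the stream of the backend instance it was created from, and all
// allocations/frees are issued on that stream.
struct async_cuda_buffer_type {
    cudaStream_t stream; // taken from the ggml_backend instance at creation time
};

static void * async_buffer_alloc(async_cuda_buffer_type * buft, size_t size) {
    void * ptr = nullptr;
    cudaMallocAsync(&ptr, size, buft->stream);
    return ptr;
}

static void async_buffer_free(async_cuda_buffer_type * buft, void * ptr) {
    cudaFreeAsync(ptr, buft->stream);
}
```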

The problems with the stream-ordered allocator are not unfixable either, but at the very least it would be necessary to check that the feature is available before using it. I think it should be possible to use the CUDA default stream with cudaMallocAsync, or just another dummy stream if that is not allowed by CUDA. It should still be optional so that it can be disabled in case it doesn't work well on some systems. I don't remember the exact issues we had when using the stream-ordered allocator in the CUDA pool, but it is not just HIP; some NVIDIA devices don't support it, and in some cases it caused issues even when supported.
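The availability check could look something like this (a sketch using the standard CUDA device attribute):

```cpp
#include <cuda_runtime.h>

// Sketch of the availability check: only use cudaMallocAsync when the device
// reports support for memory pools.
static bool device_supports_memory_pools(int device) {
    int supported = 0;
    if (cudaDeviceGetAttribute(&supported, cudaDevAttrMemoryPoolsSupported, device) != cudaSuccess) {
        return false;
    }
    return supported != 0;
}
```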

However, I am not sure that it would be worth adding all of this complexity. It seems that it is only an issue for applications that allocate and deallocate memory very frequently. Using the stream-ordered allocator in this way does the same thing as reusing the same buffers, it keeps the memory in a pool instead of returning it to the OS. Obviously this cannot be the default behaviour, since applications may expect the memory of freed buffers to be returned to the OS, so we also need to add another setting to enable or disable this feature. If you are not planning on using this feature for long since you are planning to write your own backend, I am not sure that we need to merge this.

@young-developer (Contributor) commented
I tried to add CUDA memory pools (#3903) and it worked only for one GPU; it failed for multiple GPUs or different architectures. There was no mind-blowing performance boost, so I just applied the first rule of optimization - https://wiki.c2.com/?RulesOfOptimization
If you want you can add it, but take into account multiple use cases, devices, etc.

@YavorGIvanov (Author) commented Feb 9, 2024

> I tried to add CUDA memory pools (#3903) and it worked only for one GPU; it failed for multiple GPUs or different architectures. There was no mind-blowing performance boost, so I just applied the first rule of optimization - https://wiki.c2.com/?RulesOfOptimization. If you want you can add it, but take into account multiple use cases, devices, etc.

I see. However, you were adding it as a replacement for the current pool in the CUDA backend. That wouldn't actually solve the issue I am facing, because the ggml-cuda backend currently does not allocate these buffers from the pool at all; it calls cudaMalloc(..) directly in ggml_backend_cuda_buffer_type_alloc_buffer(). And just changing the calls there to their async versions does not fit the current design, as explained by @slaren. Additionally, 300-400 ms of initial KV cache allocation latency is definitely worth optimizing and not "premature".

> If you are not planning on using this feature for long since you are planning to write your own backend, I am not sure that we need to merge this.

I do not think this is equivalent to keeping the KV cache buffer allocated in my case. The inference parameters (e.g. batch size) and the input influence the KV cache size, and they change. If I do not reallocate the KV cache buffer whenever it is already big enough for the new parameters, it will stay at the maximum size it has ever reached, and other things in the same process won't be able to use the unused GPU memory. This is not the case when using the maximum release threshold of the default memory pool. Additionally, I am not sure there aren't other use cases where you simply run inference multiple times with different parameters and batch sizes in a single process over time; that doesn't seem crazy to me. I am not sure whether there are advantages in the general case, though, as I don't have a view of all the use cases of the library. For my use case it is clearly worth it.

I will continue using the change in my custom version of the ggml-cuda backend, but I hoped I could also keep syncing with the latest versions you release; the sync is obviously easier if all my changes are in the library. However, the more important question is whether the potential added implementation complexity you describe (I understand it and can do it) would bring enough benefit to justify itself. From your comments I get the sense that you think it does not, so I guess it doesn't make sense for me to put in the effort if you don't see potential advantages (where supported) over the current implementation.

@slaren (Collaborator) commented Feb 9, 2024

The only parameters that would cause a change in the KV size are the number of slots (n_ctx) and the types (type_k, type_v). If your goal is to test different batch sizes, you can do that by reallocating only the compute buffer.
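As a rough illustration (generic transformer KV layout, numbers not tied to this setup):

```cpp
#include <cstddef>

// Illustrative only (generic transformer KV layout, not a specific model):
// the KV cache stores K and V for every layer and every context slot, so its
// size scales with n_ctx and the element types, not with the batch size.
static size_t kv_cache_bytes(size_t n_layer, size_t n_ctx, size_t n_embd_kv, size_t type_size) {
    return 2 /* K and V */ * n_layer * n_ctx * n_embd_kv * type_size;
}
// e.g. 32 layers, n_ctx = 32768, n_embd_kv = 4096, F16 (2 bytes) -> 16 GiB
```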

```diff
     void * dev_ptr;
-    cudaError_t err = cudaMalloc(&dev_ptr, size);
+    cudaError_t err = cudaMallocAsync(&dev_ptr, size, main_stream);
```
@Artefact2 (Collaborator) commented Feb 9, 2024
This breaks the build on ROCm. Consider

```diff
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 459e4376..9dacf34b 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -73,6 +73,8 @@
 #define cudaMalloc hipMalloc
 #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
 #endif
+#define cudaMallocAsync hipMallocAsync
+#define cudaFreeAsync hipFreeAsync
 #define cudaMemcpy hipMemcpy
 #define cudaMemcpyAsync hipMemcpyAsync
 #define cudaMemcpyPeerAsync hipMemcpyPeerAsync
```

@YavorGIvanov (Author) commented

> The only parameters that would cause a change in the KV size are the number of slots (n_ctx) and the types (type_k, type_v). If your goal is to test different batch sizes, you can do that by reallocating only the compute buffer.

I am determining the KV cache size and allocating it based on the batch size or the current number of jobs submitted for inference. Preallocating the KV cache with a size that allows the maximum supported number of jobs is equivalent to what you are suggesting. That is usually fine, but it is still less flexible. I will leave this PR as a "demo" for now.

@YavorGIvanov marked this pull request as draft on February 13, 2024.
@ggerganov added the "demo" label (Demonstrate some concept or idea, not intended to be merged) on February 13, 2024.