GPU Memory #539
Comments
@ross-h1 - Is 900 the size of the flattened variables? That's not a lot, and I would be surprised if you are running into memory issues unless you are collecting all the intermediate HMC fields (like the mass matrix). As @fehiepsi mentioned, you can just collect the …
@ross-h1 I think using …
@ross-h1 Thank you for your example. I too am running out of memory on GPU with millions of parameters, and am using transform=lambda state: state.z just to pull back the samples I need. Your code, with some modification, works perfectly for me, except each cycle seems to run slower than the ones before. My hope was that fori_collect would just pick up where the last sampling left off, but each cycle takes longer. Here's a minimally reproducible example: it starts off at 25 items per second and drops to 12 by the end. Can anyone see a problem, or tell me whether fori_collect should or shouldn't take approximately constant time per batch? Incorporating the ability to continue from previous sampling into MCMC would be hugely appreciated.
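For context, a minimal sketch of the kind of batched fori_collect loop being described, written against numpyro's low-level HMC API (initialize_model / hmc / fori_collect, roughly as of the 0.4 releases discussed below) and a toy model; this is illustrative, not the original code:

```python
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer.hmc import hmc
from numpyro.infer.util import initialize_model
from numpyro.util import fori_collect

def model():
    numpyro.sample("a", dist.Normal(0., 0.2), sample_shape=(365, 100))

model_info = initialize_model(random.PRNGKey(0), model)
init_kernel, sample_kernel = hmc(model_info.potential_fn, algo="NUTS")
state = init_kernel(model_info.param_info, num_warmup=300, rng_key=random.PRNGKey(1))

batches = []
for i in range(10):
    # collect a small batch of samples, pulling back only the latent values (state.z)
    batch, state = fori_collect(0, 100, sample_kernel, state,
                                transform=lambda s: s.z,
                                return_last_val=True)
    batches.append(batch)
```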
@rexdouglass The time might vary due to the change in the number of leapfrog steps (you are using NUTS and adapting step_size). You can sequentially get batches of samples with

```python
from numpyro.infer import NUTS, MCMC

mcmc = MCMC(NUTS(test1), 300, 100)
batches = []
for i in range(10):
    mcmc.run(random.PRNGKey(i))
    batches.append(mcmc.get_samples())
    mcmc._warmup_state = mcmc._last_state
```

Does this resolve your issue? We might expose a method to do the job: …
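One way (a sketch, not part of the suggestion above) to keep the GPU footprint small in such a loop is to copy each batch to host memory as it is produced, e.g. with jax.device_get, and concatenate on the CPU:

```python
import numpy as np
from jax import device_get, random
from numpyro.infer import MCMC, NUTS

mcmc = MCMC(NUTS(test1), 300, 100)
batches = []
for i in range(10):
    mcmc.run(random.PRNGKey(i))
    # move this batch of samples from the device into host NumPy arrays
    batches.append(device_get(mcmc.get_samples()))
    mcmc._warmup_state = mcmc._last_state

# stitch the batches together on the CPU
samples = {k: np.concatenate([b[k] for b in batches]) for k in batches[0]}
```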
Can confirm this runs in linear time (each batch takes approximately the same amount of time). Thank you for the quick help! I'll open an issue ticket recommending this as a feature, for the following reason: even though your code is trivial to implement (but immensely appreciated), it isn't 100% obvious that someone ought to do this. As a newbie coming from Stan/TF Probability, I wasn't sure that hitting run again would continue sampling the chain from where we left off. Here I see we need to seed it with another random key and set _warmup_state to where we left off, so I was right to ask. The ability is useful in my case, where the GPU is memory bound and small batches are the only way to go.
@fehiepsi Here is the loop as you designed it, plus a call to nvmlDeviceGetMemoryInfo to show the memory used at each iteration. Each loop, the memory used strictly increases (I run into the overflow again at my actual sizes, though not in the toy example). Perhaps there's a garbage-clearing step to add? XLA_PYTHON_CLIENT_ALLOCATOR="platform" is described as freeing on demand. I added an explicit del of the samples, to no effect. [Updated with an explicit call to gc.collect() just in case, but no effect]
```python
import os
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"  # memory management at go time
# https://github.com/pyro-ppl/numpyro/issues/539
# https://github.com/pyro-ppl/numpyro/issues/735

import gc
import jax
import numpy as np
import numpy as onp
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS
from numpyro.infer.hmc import hmc
from numpyro.infer.util import initialize_model
from numpyro.util import fori_collect
from pynvml import *

numpyro.set_platform("gpu")
print(jax.__version__)                            # 0.2.3
print(numpyro.__version__)                        # 0.4.1
print(jax.config.FLAGS.jax_backend_target)        # local
print(jax.lib.xla_bridge.get_backend().platform)  # gpu

GPU_mem_state = None
try:
    nvmlInit()
    GPUhandle = nvmlDeviceGetHandleByIndex(0)
    numpyro.set_platform("gpu")

    def GPU_mem_state():
        info = nvmlDeviceGetMemoryInfo(GPUhandle)
        return "Used GPU Memory MB {:,}".format(onp.round(info.used / 1000000, 2))
except:
    print("Cant initialise GPU, Using CPU")


def test1():
    a = numpyro.sample("a", dist.Normal(0., 0.2), sample_shape=(365, 100))


mcmc = MCMC(NUTS(test1), 100, 100)
for i in range(10):
    print("\n" + GPU_mem_state())
    mcmc.run(random.PRNGKey(i))
    samples = mcmc.get_samples()
    # pull each sample site back to host memory as a NumPy array
    trace = [onp.atleast_1d(onp.asarray(f)) for f in samples.values()]
    del samples
    mcmc._warmup_state = mcmc._last_state
    gc.collect()
```
Interesting! We should look into this issue. I was looking into the numpyro codebase, especially some caching mechanisms, but couldn't find the source of the memory leak. I'm trying to install some profiling tools from here to debug this, but I am not familiar with those tools so it might take time. If you have any suggestions, please let me know. Really appreciate your help! :)
Much appreciated. I'm new to numpyro so I can only be a guinea pig and talk through ideas. One immediate thought is that it's not a small memory leak; it just looks small because of the toy test. I can do some experiments and calculate exactly how many bytes it's not releasing after each run. I thought maybe @ross-h1 turned to fori_collect out of necessity. Some background: this is a covid nowcast that I tried and failed to get to scale in Stan, and then in Greta, which wraps tf-probability. The ability to auto-magically run a large model on the GPU is a godsend. The current hurdle is just memory, and the unnecessary need to store both the model and the full samples on the GPU at the same time. I need some way to port them off as we go so I can sample a chain for a long time. Thanks very much for the help.
Thanks for taking a look! I think I have found the reason for the memory leak. Will push a fix soon.
Dear NumPyro team @fehiepsi @neerajprad,

Background: …
Good News: …
Bad News: …
Further report (using only 2 chains): …
Multiple GPUs: …
Single TPU: …

Thanks for your help.

System: Ubuntu 18.01
Python libraries: jax==0.2.8 …
@PaoloRanzi81 On 1 GPU, JAX does not support parallel map, so I guess you are drawing samples sequentially. Did you get a warning that …

If you have multiple GPUs, using the default one …

I think the same applies for TPUs. JAX does not support running parallel map on a single device.
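For illustration, a sketch of the single-device alternatives (assuming the MCMC API of the numpyro versions discussed in this thread and the test1 model from the snippets above):

```python
from jax import random
from numpyro.infer import MCMC, NUTS

# on a single GPU, chains cannot be mapped across devices, so they are drawn
# either one after another ("sequential") or batched together ("vectorized")
mcmc = MCMC(NUTS(test1), num_warmup=1000, num_samples=1000,
            num_chains=4, chain_method="vectorized")
mcmc.run(random.PRNGKey(0))
```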
Thanks @fehiepsi for the kind feedback! I wanted to confirm that Numpyro works great (great results, identical to PyMC3!) below 51,000 MCMC samples with chains = 8 when using a GPU. Unfortunately, above that threshold the RAM memory blows up.

https://github.com/pymc-devs/pymc3/blob/master/pymc3/sampling_jax.py

Thanks in advance for looking into the issue.
I see, thanks a lot! It seems that there are memory leaks with the "sequential" method (all of the tests that you did point to this one). The issue does not seem to affect the default "parallel" method or the "vectorized" method. I'll look into this. Edit: by the way, did you enable the progress bar or not (with …
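Presumably this refers to the progress_bar flag of the MCMC constructor; a minimal illustration (the numbers are arbitrary and the test1 model is the toy one from earlier in the thread):

```python
from jax import random
from numpyro.infer import MCMC, NUTS

# drawing chains sequentially with the progress bar disabled
mcmc = MCMC(NUTS(test1), num_warmup=1000, num_samples=1000,
            num_chains=2, chain_method="sequential", progress_bar=False)
mcmc.run(random.PRNGKey(0))
```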
I think so. I don't have 2 GPUs to test with, but that's my understanding from when we added the "parallel", "vectorized", and "sequential" methods for drawing chains of samples.
I am not sure how many devices a TPU exactly has. Testing on Cloud TPU colabs, I saw 8 TPU devices:

```python
from jax.lib import xla_bridge
xla_bridge.device_count()  # returns 8
```

so the parallel method should run with …

Currently, we added the following sentence in the MCMC docs: …
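For completeness, a sketch (again assuming the MCMC API of this era and the toy test1 model) of drawing 8 chains in parallel on an 8-device TPU:

```python
import numpyro
from jax import random
from numpyro.infer import MCMC, NUTS

numpyro.set_platform("tpu")

# one chain per TPU device; "parallel" is the default chain_method
mcmc = MCMC(NUTS(test1), num_warmup=1000, num_samples=1000,
            num_chains=8, chain_method="parallel")
mcmc.run(random.PRNGKey(0))
samples = mcmc.get_samples(group_by_chain=True)
```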
Thank you @fehiepsi! I really hope you can solve the Out-Of-Memory (OOM) error with the RAM, which is blocking lots of people from using multi-GPUs or multi-TPUs with Numpyro! Below is what I think the 3 priorities are.

**First priority: updating the README.md file in the GitHub Numpyro repository**

I am very surprised that nobody has ever tested Numpyro with multi-GPUs or multi-TPUs. In the README.md, GPUs and TPUs seem to be a big selling point for Numpyro. Unfortunately, the README.md does not even mention those nasty OOM errors... Therefore, I think the README.md must be more truthful and transparent. It should clearly state that Numpyro does NOT work with GPUs and TPUs for professional Bayesian models (i.e. complex models with 450,000 MCMC samples). Indeed, the OOM error limits the use of GPUs and TPUs to very trivial models using fewer than 50,000 MCMC samples. Unless Numpyro gets rigorously tested with multi-GPUs or multi-TPUs and the OOM error is solved, we cannot hide the fact that the OOM error completely blocks Numpyro for professional analyses. I understand it is painful to admit that, but it is the reality we are facing. At the very least, the README.md should warn researchers and practitioners NOT to waste thousands of their working hours trying Numpyro on GPUs or TPUs for professional use cases.

**Second priority: solving the OOM error**

The massive consumption of RAM memory is a major roadblock for using Numpyro in professional cases. The OOM error must be solved. Unfortunately, it seems to be a very hard problem. Not sure if there is a connection, but JAX's issues about the OOM error are already very old and no definitive solution has been found yet. Indeed, there is a track record of people mentioning the monotonic increase of RAM in JAX almost 2 years ago. Even this very issue in Numpyro has been open since February 2020 and still no solution is available. For my needs, I cannot bring Numpyro to production since the OOM error is limiting me to 50,000 MCMC samples. I would rather have 10 times more MCMC samples than that! That's the reason why I am using the "old" PyMC3 implementation on CPUs in production. I am sure PyMC3 could be considered slow and expensive because of the use of CPUs and overheads. However, PyMC3 is reliable and it can handle 450,000 MCMC samples easily.

**Third priority: a clear tutorial about Numpyro with multi-GPUs or multi-TPUs**

Once the OOM error is solved, I agree users would like to have a complete tutorial on how to use multi-GPUs or multi-TPUs. I believe that Numpyro, PyMC3 and TensorFlow Probability are the 3 main contenders in the Python ecosystem for Bayesian modeling. It would be nice to compare them in terms of speed, but not in terms of results or learning curve. In terms of results, these 3 libraries should be -hopefully- almost identical. In terms of learning curve, it is difficult to test them against each other (because of its qualitative components).
Thanks for your feedback, @PaoloRanzi81! FYI, we did test multiple GPUs and TPUs when we added the chain methods and mentioned that support, e.g. here and here. But we haven't tested them rigorously. I agree that we should have a tutorial on multi-GPUs or multi-TPUs. We pinned an issue here to call for NumPyro users to contribute and share experience. We also made an issue here for a test suite for performance testing. Are you interested in contributing to one of those items? That would be very valuable for the community. :)

Regarding the OOM, I'm not sure the first action should be to add new claims to the README without investigating the issue. I think that is also a main theme of open-source development. For example, if …

It looks like the claim here is …
Updated: I got no error for …
If you use #893, you will see a progress bar that displays the progress. I also optimized the memory requirement a bit there, which hopefully relaxes your OOM disappointment a bit. Let's try to benchmark more when you isolate the root of the problem. Hopefully, we can diagnose the issue together and share the experience with the community.
Thanks @fehiepsi for the nice improvements! I am slowly reading the documentation you have provided. Here are my quick comments:
Hi @PaoloRanzi81, thanks for your feedback! Looking forward to hearing the result with the master branch for …
Probably (my best guess) it is an issue that we just fixed in the master branch. I don't know the relationship between RAM and GPU memory well enough to give a clear explanation... Regarding "sequential", the request from other users is to call …
Probably. I did look at the memory usage of …
Hmm, this is even smaller than the numbers that I tested. You can also test it yourself with the same code, because TPU is free to use in Colab. In the gist that I posted in the last comment, the number of samples is 8 x 100,000, i.e. 800,000 samples. I guess I might be missing something, or there is a missing … Let's hope that using the master branch with …
@fehiepsi
Unfortunately, with Numpyro and …
@PaoloRanzi81 I created #899 to track down the issue, because the previous issues in this thread have been resolved and I don't want to send unnecessary notifications to other users. Please follow up in the new issue that I created, which might catch more eyes from other users. Some of my thoughts:
GPU Memory
Hi guys, I have been having some problems with GPU memory allocation, especially with fori_collect. I've read the JAX note on this here:
https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
The particular problem I have is with a large number of variables (circa 900) and dense_mass=True. The HMCAdaptState, which is stored for each sample, then becomes very large as the number of samples increases (despite holding the same values post-warmup). Is it possible to store HMCAdaptState as an attribute of the sampler rather than of each sample? Alternatively, some kind of transfer from GPU to CPU memory and/or a file backend after each sample is generated?
Thanks again for an awesome project! Ross
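For reference, the knobs described on that JAX page are environment variables that must be set before anything is allocated on the GPU; a minimal sketch (the values are just examples):

```python
import os

# disable the large up-front preallocation of GPU memory...
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# ...or cap the preallocated fraction of GPU memory,
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"
# ...or allocate and free on demand (slower, but returns memory to the system).
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import numpyro  # import after the environment variables are set
numpyro.set_platform("gpu")
```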
Here is a batched fori_collect, with samples stored as NumPy arrays after generation. This uses the nvidia-ml-py3 library to access GPU memory usage (pip3 install nvidia-ml-py3).

Setup:
```python
import os
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import numpy as onp
import numpyro

GPU_mem_state = None
try:
    from pynvml import *
    nvmlInit()
    GPUhandle = nvmlDeviceGetHandleByIndex(0)
    numpyro.set_platform('gpu')

    def GPU_mem_state():
        info = nvmlDeviceGetMemoryInfo(GPUhandle)
        return 'Used GPU Memory MB {:,}'.format(onp.round(info.used / 1000000, 2))
except:
    print('Cant initialise GPU, Using CPU')
```
Code:

```python
def merge_flat(trace, samples):
    if trace is None:
        trace = [onp.atleast_1d(onp.asarray(f)) for f in samples]
    else:
        trace = [onp.concatenate([trace[k], onp.atleast_1d(onp.asarray(samples[k]))])
                 for k in range(len(trace))]
    return trace


def collect_max(start_collect, num_iter, sample_kernel, last_state, max_n=25):
    """
    As per numpyro fori_collect, but runs in batches of max_n size
    and returns flattened tree back on the cpu via merge_flat.
    """
```