
GPU Memory #539

Closed
ross-h1 opened this issue Feb 19, 2020 · 21 comments

Labels: performance, question (Further information is requested), usability

ross-h1 commented Feb 19, 2020

GPU Memory

Hi guys, I've been having some problems with GPU memory allocation, especially with fori_collect. I've read the JAX note on this here:

https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html

This has helped a bit, though it does slow things down. Without XLA_PYTHON_CLIENT_ALLOCATOR="platform", JAX seems very reluctant to give memory back after each batch of samples. I'm sampling in small batches and trying to transfer samples back to CPU memory after use. See the code below for a version of fori_collect that batches sampling, in case it's of use to anyone…

The particular problem I have is with a large number of variables (circa 900) and dense_mass=True. The HMCAdaptState, which is stored for each sample, then becomes very large as the number of samples increases (despite holding the same values post warmup). Is it possible to store HMCAdaptState as an attribute of the sampler, rather than of each sample? Alternatively, could there be some kind of transfer from GPU to CPU memory and/or a file backend after each sample is generated?

Thanks again for an awesome project! Ross

Here is a batched fori_collect, with samples stored as Numpy arrays after generation. This uses the Nvidia-ml-py3 library to access GPU memory usage:

(pip3 install nvidia-ml-py3)

Setup:

import os

os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import numpy as onp
import numpyro

GPU_mem_state = None
try:
    from pynvml import *
    nvmlInit()
    GPUhandle = nvmlDeviceGetHandleByIndex(0)
    numpyro.set_platform('gpu')

    def GPU_mem_state():
        info = nvmlDeviceGetMemoryInfo(GPUhandle)
        return 'Used GPU Memory MB {:,}'.format(onp.round(info.used / 1000000, 2))
except:
    print('Cannot initialise GPU, using CPU')

Code:

import numpy as onp
import jax.tree_util as jtu
from tqdm import tqdm
from numpyro.util import fori_collect

def merge_flat(trace, samples):
    # Concatenate a new batch of flattened sample leaves onto the NumPy trace
    # held in host (CPU) memory.
    if trace is None:
        trace = [onp.atleast_1d(onp.asarray(f)) for f in samples]
    else:
        trace = [onp.concatenate([trace[k], onp.atleast_1d(onp.asarray(samples[k]))])
                 for k in range(len(trace))]
    return trace

def collect_max(start_collect, num_iter, sample_kernel, last_state, max_n=25):
    """
    As per numpyro fori_collect, but runs in batches of max_n size
    and returns the flattened tree back on the CPU via merge_flat.
    """
    n = 0
    last_val = last_state
    trace = None
    n_main, n_extra = divmod(num_iter, max_n)
    with tqdm(total=num_iter) as pbar:
        for m in range(n_main):
            # number of leading draws in this batch to skip (clamped at zero)
            collect = max(max_n - max(n + max_n - start_collect, 0), 0)
            samples, last_val = fori_collect(collect, max_n, sample_kernel, last_val,
                                             return_last_val=True, progbar=False)
            pbar.update(max_n)
            if GPU_mem_state is not None:
                pbar.set_description(GPU_mem_state())
            if max_n - collect > 0:
                trace = merge_flat(trace, jtu.tree_leaves(samples))
                del samples
            n += max_n
        if n_extra > 0:
            collect = max(n_extra - max(n + n_extra - start_collect, 0), 0)
            samples, last_val = fori_collect(collect, n_extra, sample_kernel, last_val,
                                             return_last_val=True, progbar=False)
            pbar.update(n_extra)
            if n_extra - collect > 0:
                trace = merge_flat(trace, jtu.tree_leaves(samples))
                del samples
    return trace, last_val
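
A hypothetical usage sketch, assuming a sample_kernel / hmc_state pair built with numpyro.infer.hmc.hmc as in the snippet later in this thread:

# collect 1000 draws in batches of 25, keeping everything from iteration 0 onwards
trace, last_state = collect_max(start_collect=0, num_iter=1000,
                                sample_kernel=sample_kernel,
                                last_state=hmc_state, max_n=25)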
@fehiepsi (Member)

@ross-h1 You can use transform arg to only collect fields you want. E.g., transform=lambda state: state.z
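
For reference, a minimal sketch of that suggestion (assuming the sample_kernel / last_state pair from the hmc setup used elsewhere in this thread, and a hypothetical num_samples count):

from numpyro.util import fori_collect

# Collect only the latent values (state.z) instead of the full HMC state pytree,
# so the per-sample HMCAdaptState (mass matrix, step size, ...) is never stored.
samples, last_state = fori_collect(0, num_samples, sample_kernel, last_state,
                                   transform=lambda state: state.z,
                                   return_last_val=True, progbar=False)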

@neerajprad (Member)

@ross-h1 - Is 900 the size of the flattened variables? That's not a lot, and I would be surprised if you are running into memory issues unless you are collecting all the intermediate HMC fields (like the mass matrix). As @fehiepsi mentioned, you can just collect the z field. Let us know if that solves your problem.

@neerajprad neerajprad reopened this Feb 19, 2020
@fehiepsi fehiepsi added the question Further information is requested label Mar 3, 2020

fehiepsi commented Mar 3, 2020

@ross-h1 I think using transform arg will resolve the issue for you. Please open this again if it doesn't help. For some large models, JAX might also require more memory than it should (I am not sure).

@fehiepsi fehiepsi closed this as completed Mar 3, 2020

rexdouglass commented Jan 2, 2021

@ross-h1 Thank you for your example. I too am running out of memory on the GPU with millions of parameters, and am using transform=lambda state: state.z just to pull back the samples I need.

Your code works perfectly for me with some modification, except that each cycle seems to run slower than the ones before. My hope was that fori_collect would just pick up where the last sampling left off, but each cycle takes longer.

Here's a minimally reproducible example where:

  1. It does only one thing, samples from a normal distribution
  2. I've commented out the trace_merge so there's no overhead from increasingly long results (we throw away each after sampling)
  3. I print the start and stop indexes so we know we're not accidentally sampling more.

It starts off at 25 items per second and drops to 12 by the end. Can anyone see a problem, or tell me whether fori_collect should or shouldn't take approximately constant time per batch? Incorporating the ability to continue from previous sampling into MCMC would be hugely appreciated.

0%|          | 0/1000 [00:00<?, ?it/s]
 Sampled so far 0
To sample 100
Used GPU Memory MB 2,353.07:  10%|█         | 100/1000 [00:03<00:35, 25.28it/s]
 Sampled so far 100
To sample 100
Used GPU Memory MB 2,341.99:  20%|██        | 200/1000 [00:09<00:37, 21.50it/s]
 Sampled so far 200
To sample 100
Used GPU Memory MB 2,343.17:  30%|███       | 300/1000 [00:13<00:32, 21.67it/s]
 Sampled so far 300
To sample 100
Used GPU Memory MB 2,343.7:  40%|████      | 400/1000 [00:19<00:30, 19.57it/s] 
 Sampled so far 400
To sample 100
Used GPU Memory MB 2,344.03:  50%|█████     | 500/1000 [00:26<00:29, 17.16it/s]
 Sampled so far 500
To sample 100
Used GPU Memory MB 2,343.76:  60%|██████    | 600/1000 [00:34<00:26, 14.96it/s]
 Sampled so far 600
To sample 100
Used GPU Memory MB 2,344.48:  70%|███████   | 700/1000 [00:44<00:22, 13.07it/s]
 Sampled so far 700
To sample 100
Used GPU Memory MB 2,343.37:  80%|████████  | 800/1000 [00:55<00:17, 11.54it/s]
 Sampled so far 800
To sample 100
Used GPU Memory MB 2,343.96:  90%|█████████ | 900/1000 [01:07<00:09, 10.33it/s]
 Sampled so far 900
To sample 100
Used GPU Memory MB 2,344.35: 100%|██████████| 1000/1000 [01:20<00:00, 12.37it/s]
import os
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"]="platform" 
import numpyro
import jax
numpyro.set_platform("gpu") 

print(jax.__version__) #0.2.3
print(numpyro.__version__) #0.4.1
print(jax.config.FLAGS.jax_backend_target) #local
print(jax.lib.xla_bridge.get_backend().platform) #gpu

import numpy as np
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS
from numpyro.infer.hmc import hmc
from numpyro.infer.util import initialize_model
from numpyro.util import fori_collect
import numpy as onp
import numpyro
import jax
from pynvml import *

def test1():
  a = numpyro.sample("a", dist.Normal(0., 0.2), sample_shape=(365,100)) #

rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)

model_info = initialize_model(random.PRNGKey(0), test1, model_args=()) #model_args=(data, labels,)
init_kernel, sample_kernel = hmc(model_info.potential_fn, algo="NUTS")
hmc_state = init_kernel(model_info.param_info,
                         #trajectory_length=10,
                         num_warmup=300
                         )

max_batch_size=100
total_to_sample=1000
sampled_so_far=0
last_val=hmc_state
import tqdm
with tqdm.tqdm(total=total_to_sample) as pbar:
    while sampled_so_far<total_to_sample:
        to_sample=min(max_batch_size,total_to_sample-sampled_so_far)
        print("Sampled so far", sampled_so_far)
        print("To sample", to_sample)
        samples, last_val = fori_collect(lower=sampled_so_far, 
                                         upper=sampled_so_far+to_sample, 
                                         body_fun=sample_kernel,
                                         init_val=last_val,
                                         return_last_val=True,
                                         progbar=False,
                                         transform=lambda state: model_info.postprocess_fn(state.z) #make sure you do this to limit what you pull back
                                         )
        pbar.update(to_sample)
        #trace=merge_flat(trace,jax.tree_leaves(samples))
        del samples
        sampled_so_far+=to_sample


fehiepsi commented Jan 3, 2021

@rexdouglass The time might vary due to the change in the number of leapfrog steps (you are using NUTS and adapting step_size).

You can sequentially get batches of samples with

from numpyro.infer import NUTS, MCMC

mcmc = MCMC(NUTS(test1), 300, 100)
for i in range(10):
    mcmc.run(random.PRNGKey(i))
    batches = [mcmc.get_samples()]
    mcmc._warmup_state = mcmc._last_state

Does this resolve your issue? We might expose a method to do the job: mcmc._warmup_state = mcmc._last_state if you think this feature is useful. Please open a separate feature request in that case. Thanks!
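
A possible variation on the same idea (a sketch, not an official API) that also moves each batch off the GPU right away with jax.device_get, assuming the toy test1 model from the comment above:

import numpy as onp
from jax import random, device_get
from numpyro.infer import MCMC, NUTS

# Draw samples in small batches; each batch is copied to host (NumPy) memory so
# only the current batch lives on the GPU.
mcmc = MCMC(NUTS(test1), 300, 100)
host_batches = []
for i in range(10):
    mcmc.run(random.PRNGKey(i))
    host_batches.append(device_get(mcmc.get_samples()))  # plain NumPy arrays now
    mcmc._warmup_state = mcmc._last_state  # continue the chain where it stopped

# Concatenate the host-side batches per site.
trace = {k: onp.concatenate([b[k] for b in host_batches]) for k in host_batches[0]}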


rexdouglass commented Jan 3, 2021

Can confirm this runs in linear time (each batch takes approximately the same amount of time). Thank you for the quick help!

I'll open an issue ticket recommending this as a feature, for the following reason. Even though your code is trivial to implement (but immensely appreciated), it isn't 100% obvious that someone ought to do this. As a newbie coming from Stan/TF-probability I wasn't sure that hitting run again would continue sampling the chain from where we left off. Here I see we need to seed it another random key and set the _warmup_state to where we left off, so I was right to ask. The ability is useful in my case where the GPU is memory bound and so small batches are the only way to go.


rexdouglass commented Jan 3, 2021

@fehiepsi
Apologies, the timing remains constant but the memory use does not.

Here is the loop as you designed it, plus a call to nvmlDeviceGetMemoryInfo to show the memory used at each iteration. Each loop, the memory used strictly increases (I run out of memory again at my actual sizes, though not in this toy example).

Perhaps there's a garbage-clearing step to add? XLA_PYTHON_CLIENT_ALLOCATOR="platform" is described as freeing on demand. I added an explicit del of samples, to no effect. [Updated with an explicit call to gc.collect() just in case, but no effect.]

Used GPU Memory MB 3,070.36
sample: 100%|██████████| 400/400 [00:10<00:00, 39.76it/s, 31 steps of size 1.22e-01. acc. prob=0.82] 
  0%|          | 0/100 [00:00<?, ?it/s]
Used GPU Memory MB 2,280.98
sample: 100%|██████████| 100/100 [00:01<00:00, 69.82it/s, 31 steps of size 1.22e-01. acc. prob=0.83]
sample:   7%|▋         | 7/100 [00:00<00:01, 65.73it/s, 31 steps of size 1.22e-01. acc. prob=0.83]
Used GPU Memory MB 2,294.94
sample: 100%|██████████| 100/100 [00:01<00:00, 66.56it/s, 31 steps of size 1.22e-01. acc. prob=0.83]
  0%|          | 0/100 [00:00<?, ?it/s]
Used GPU Memory MB 2,306.93
sample: 100%|██████████| 100/100 [00:01<00:00, 68.35it/s, 31 steps of size 1.22e-01. acc. prob=0.83]
sample:   6%|▌         | 6/100 [00:00<00:01, 59.24it/s, 31 steps of size 1.22e-01. acc. prob=0.83]
Used GPU Memory MB 2,321.94
sample: 100%|██████████| 100/100 [00:01<00:00, 70.24it/s, 31 steps of size 1.22e-01. acc. prob=0.83]
sample:   7%|▋         | 7/100 [00:00<00:01, 67.18it/s, 31 steps of size 1.22e-01. acc. prob=0.83]
Used GPU Memory MB 2,336.62
sample: 100%|██████████| 100/100 [00:01<00:00, 70.31it/s, 31 steps of size 1.22e-01. acc. prob=0.83]
  0%|          | 0/100 [00:00<?, ?it/s]
Used GPU Memory MB 2,349.6
sample: 100%|██████████| 100/100 [00:01<00:00, 70.87it/s, 31 steps of size 1.22e-01. acc. prob=0.83]
  0%|          | 0/100 [00:00<?, ?it/s]
Used GPU Memory MB 2,365.65
sample: 100%|██████████| 100/100 [00:01<00:00, 69.52it/s, 31 steps of size 1.22e-01. acc. prob=0.83]
sample:   7%|▋         | 7/100 [00:00<00:01, 67.55it/s, 31 steps of size 1.22e-01. acc. prob=0.83]
Used GPU Memory MB 2,379.22
sample: 100%|██████████| 100/100 [00:01<00:00, 70.54it/s, 31 steps of size 1.22e-01. acc. prob=0.83]
  0%|          | 0/100 [00:00<?, ?it/s]
Used GPU Memory MB 2,394.03
sample: 100%|██████████| 100/100 [00:01<00:00, 70.92it/s, 31 steps of size 1.22e-01. acc. prob=0.83]
import os
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"]="platform" #memory management at go time #https://github.com/pyro-ppl/numpyro/issues/539

#https://github.com/pyro-ppl/numpyro/issues/735
import numpyro
import jax
numpyro.set_platform("gpu") 

print(jax.__version__) #0.2.3
print(numpyro.__version__) #0.4.1
print(jax.config.FLAGS.jax_backend_target) #local
print(jax.lib.xla_bridge.get_backend().platform) #gpu

import numpy as np
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS
from numpyro.infer.hmc import hmc
from numpyro.infer.util import initialize_model
from numpyro.util import fori_collect
import numpy as onp
import numpyro
import jax
from pynvml import *

GPU_mem_state=None
try:
    nvmlInit()
    GPUhandle=nvmlDeviceGetHandleByIndex(0)
    numpyro.set_platform("gpu")
    
    def GPU_mem_state():
        info = nvmlDeviceGetMemoryInfo(GPUhandle)
        return "Used GPU Memory MB {:,}".format(onp.round(info.used/1000000,2))
except:
    print ("Cant initialise GPU, Using CPU")
    
def test1():
  a = numpyro.sample("a", dist.Normal(0., 0.2), sample_shape=(365,100)) 

import gc
from numpyro.infer import NUTS, MCMC
mcmc = MCMC(NUTS(test1), 100, 100)
for i in range(10):
    print("\n"+GPU_mem_state())
    mcmc.run(random.PRNGKey(i))
    samples = mcmc.get_samples()
    trace = [onp.atleast_1d(onp.asarray(f)) for f in samples.values()]  # copy to host NumPy
    del samples
    mcmc._warmup_state = mcmc._last_state
    gc.collect()


fehiepsi commented Jan 3, 2021

Interesting! We should look into this issue. I was looking into the numpyro codebase, especially some caching mechanisms, but couldn't find the source of the memory leak. I'm trying to install some profiling tools from here to debug this, but I am not familiar with those tools, so it might take time. If you have any suggestions, please let me know. Really appreciate your help! :)
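
For anyone following along, one way to inspect device memory between runs is JAX's device memory profiler. A sketch, assuming a recent JAX and the toy test1 model above; the dumps can be compared with pprof as described in the JAX device-memory-profiling docs:

import jax.profiler
from jax import random
from numpyro.infer import MCMC, NUTS

# Dump a device memory profile after each sequential run so the .prof files can
# be diffed to see which buffers are not being released.
mcmc = MCMC(NUTS(test1), 100, 100)
for i in range(5):
    mcmc.run(random.PRNGKey(i))
    mcmc._warmup_state = mcmc._last_state
    jax.profiler.save_device_memory_profile("memory_iter_{}.prof".format(i))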

@fehiepsi fehiepsi reopened this Jan 3, 2021
@rexdouglass

Much appreciated. I'm new to numpyro, so I can only be a guinea pig and talk through ideas.

One immediate thought is that it's not a small memory leak; it just looks small because of the toy test. I can do some experiments and calculate exactly how many bytes it's not releasing after each run. I thought maybe @ross-h1 turned to fori_collect out of necessity.

Some background: this is a COVID nowcast that I tried and failed to get to scale in Stan, and then in Greta, which wraps TF-Probability. The ability to auto-magically run a large model on the GPU is a godsend. The current hurdle is just memory, and the unnecessary need to store both the model and the full samples on the GPU at the same time. I need some way to port them off as we go so I can sample a chain for a long time.

Thanks very much for the help.


fehiepsi commented Jan 4, 2021

Thanks for taking a look! I think I have found the reason for the memory leak. Will push a fix soon.

@PaoloRanzi81

Dear Numpyro team @fehiepsi @neerajprad,
For a couple of weeks I have been experiencing a nasty OOM RAM error with both the “parallel” and “vectorized” versions of Numpyro's NUTS sampler. Today I tried numpyro==0.5.0 hoping that the OOM problem had been solved. Unfortunately, it has not.

Background:

  • I am using Numpyro via PyMC3 (current master branch, which uses JAX + Numpyro) + Theano-PyMC (version 1.1).
  • The data-set is quite small (0.5 MB). However, the model (hierarchical + non-centered parametrization and several other Bayesian tricks) is quite big, with lots of hyperpriors. In total, I have ~ 120 posterior distributions which need to be estimated by the NUTS sampler.
  • The model has been extensively tested. It works great with PyMC3 3.10! The machine I used with pure PyMC3 is the following: 32 vCPUs + 24 GB RAM. The NUTS sampler's characteristics: 450000 total MCMC samples (namely: draws = 1000; tune = 12062; chains = 32). Within ~ 3 Hrs it finishes sampling. The RAM never goes above 6 GB.

Good News:
Using the PyMC3 + Numpyro “parallel” version on vCPUs, Numpyro performs similarly to the pure PyMC3 3.10 version both in terms of run-time and RAM usage. Most importantly, the posterior distributions of the Numpyro version are the same as the pure PyMC3 version!!! Congrats to Numpyro's team!

Bad News:
I have tried the same model in Numpyro using 1 GPU (NVIDIA Tesla, 16 GB). Both the “parallel” and “vectorized” versions blew up the RAM (with 39 GB RAM available, still an Out-Of-Memory (OOM) error...). As described several times in Numpyro's and JAX's issues, I also confirm that the RAM increases monotonically.
The largest run I was able to complete successfully (thus without getting an OOM error) was a total of 51000 MCMC samples (namely: draws = 1000; tune = 5375; chains = 8). However, it took 27 GB of RAM, 20 GB more than the pure PyMC3 version. Everything above 51000 MCMC samples triggered an OOM error.

Further report (using only 2 chains):
I have noticed a very weird behavior. I used either the “parallel” or the “vectorized” version with a total of 10750 MCMC samples (namely: draws = 1000; tune = 5375; chains = 2). The OOM error still appeared, even with such a small number of chains!
With the “vectorized” version I did see a noticeable difference in GPU power usage: with fewer chains, less usage of the GPU. That was expected. Nevertheless, the OOM was still present.
With the “parallel” version I was expecting less strain on the RAM: I was thinking that with only 2 chains run in sequence it would unload some of the weight from the RAM to the GPU. Instead, that was not the case. It was hard to believe, but with “parallel” + chains = 2 the monotonic increase was even faster (0.01 GB RAM/sec) than with “vectorized” + chains = 8 (0.001 GB RAM/sec).
In conclusion: “vectorized” + chains = 8 + 51000 total MCMC samples gave no OOM error, whereas “parallel” + chains = 2 + 10750 total MCMC samples did.

Multiple GPUs:
This OOM problem makes me very skeptical about trying Numpyro and its “parallel” version with multiple GPUs (again, each GPU running an individual chain). I am afraid that the OOM error would still persist. What do you think about that?

Single TPU:
Further, I am afraid that the same OOM error would also jeopardize the use of a single TPU (not tried yet). With a TPU it looked natural to me to run 1 chain/core, thus 8 chains in total, using Numpyro “parallel” on a single TPU. Correct me if I am wrong: do you think that when running Numpyro on a single TPU the OOM error would disappear? Or will it suffer from the same OOM curse that already afflicts GPUs?

Thanks for your help.

System:

Ubuntu 18.01
Cuda 10.01
Python 3.7.6

Python libraries:

jax==0.2.8
jaxlib==0.1.59+cuda101
numpyro==0.5.0
pymc3 (current master branch, 20200125)
Theano-PyMC==1.1.0


fehiepsi commented Jan 25, 2021

@PaoloRanzi81 On 1 GPU, JAX does not support parallel map, so I guess you are drawing samples sequentially. Did you get a warning that There are not enough devices to run parallel chains: expected ...? I haven't looked at chain_method="sequentially" recently, so probably there is a leak over there. Could you confirm that you are using chain_method="sequentially" in your case? I also want to know how many times you call the .run method - that can be a hint for me to understand what's going on.

If you have multiple GPUs, use the default chain_method="parallel" with num_chains <= num_GPUs; then things should be good for you, because that is the intended usage.

Do you think that running Numpyro on a single TPU the OOM error would disappear? Or will it suffer of the same OOM’s curse that already afflicts GPUs?

I think the same applies to TPU. JAX does not support running parallel map on a single device.
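
As a concrete illustration of the chain methods discussed above (a sketch; it assumes the toy test1 model from earlier in the thread, and which method actually runs in parallel depends on jax.device_count()):

import jax
from jax import random
from numpyro.infer import MCMC, NUTS

print(jax.device_count())  # number of XLA devices visible to JAX

# "parallel" pmaps one chain per device and falls back to sequential drawing when
# there are fewer devices than chains; "vectorized" vmaps all chains on one device.
mcmc = MCMC(NUTS(test1), num_warmup=500, num_samples=1000,
            num_chains=2, chain_method="parallel")
mcmc.run(random.PRNGKey(0))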

@fehiepsi fehiepsi reopened this Jan 25, 2021
@PaoloRanzi81

Thanks @fehiepsi for the kind feedback! I wanted to confirm that Numpyro works great (great results, identical to PyMC3!!!) below 51000 MCMC samples + chains = 8 on a GPU. Unfortunately, above that threshold, the RAM blows up.

  1. I confirm I got the message “There are not enough devices to run parallel chains: expected… but got… Chains will be drawn sequentially”. I had overlooked chain_method=“sequential” so far. Anyhow, checking Numpyro's source code, it looks like even if I inadvertently typed chain_method=“parallel”, it reverts automatically to chain_method=“sequential”.
    After your suggestion, I explicitly typed chain_method=“sequential” (and not “sequentially”) and tested it. Unfortunately, same result: the RAM keeps increasing monotonically, very fast, until it breaks with an OOM error with just chains=2 + draws=1000 + tune=5375.

  2. Yes, what I am looking for is embarrassing parallelism: each device should run one chain. As I have read in the JAX documentation, 2 GPUs correspond to 2 devices, whereas 1 TPU corresponds to 8 devices. Could you please confirm definitively that with 2 GPUs (or 1 TPU) and chain_method=“parallel” the OOM error disappears?

  3. I have found no documentation about using >= 2 GPUs or >= 1 TPU in Numpyro. Could you point me to some good tutorials about it? I want to be sure that somebody has successfully used more than one device + chain_method=“parallel” without an OOM error.

  4. I am bringing to your attention the PyMC3 script below, where Numpyro is called from PyMC3. I am not sure how many times the .run method has been called, but it is mentioned once in the script (line 163).
    On my side, of course, I have manually changed chain_method=“parallel” (line 159) accordingly. I have also put numpyro.set_platform(platform='gpu') in my personal script (and the GPU does work, at up to 85% utilisation!). Finally, I tried manually changing “--xla_force_host_platform_device_count=2” in the PyMC3 script in order to match the 2 chains I was testing. Unfortunately, no improvement… the OOM error on the RAM is still a nasty issue…

https://github.com/pymc-devs/pymc3/blob/master/pymc3/sampling_jax.py

Thanks in advance for looking into the issue.


fehiepsi commented Jan 26, 2021

I see, thanks a lot! It seems that there is a memory leak with the "sequential" method (all of the tests that you did point to this one). The issue seems not relevant to the default "parallel" and "vectorized" methods. I'll look into this. Edit: by the way, did you enable progress_bar or not (with chain_method="sequential")? I tried both options on my GPU but did not observe the memory leak. Probably the leak happens somewhere else because, as you mentioned, it is RAM that is increasing, not GPU memory. It would be hard to guess the issue without looking at the code.

Could you please definitely confirm that using 2 GPUs and chain_method=“parallel” the OOM error disappear?

I think so. I don't have 2 GPUs to test but that's what I understand when we added those "parallel", "vectorized", "sequential" methods to draw chains of samples.

1 TPU

I am not sure exactly how many devices a TPU has. Testing on a Cloud TPU Colab, I saw 8 TPU devices:

from jax.lib import xla_bridge
xla_bridge.device_count()  # return 8

so parallel method should run with num_chains <= 8.

I have found no documentation about using >= 2 GPUs and >= 1 TPU in Numpyro

Currently, we have added the following sentence to the MCMC docs: The method ‘parallel’ is used to execute the drawing process in parallel on XLA devices (CPUs/GPUs/TPUs). If there are not enough devices for ‘parallel’, we fall back to the ‘sequential’ method to draw chains sequentially. and pointed to pmap for the num_chains argument: Number of MCMC chains to run. By default, chains will be run in parallel using jax.pmap(), failing which, chains will be run in sequence.

Do you think that users need something else to run multiple chains on GPUs? I don't have multiple GPUs to test, but I guess JAX supports pmap out of the box. I would like to hear your thoughts on how to improve the docs here. :) I think this also applies to TPUs. The JAX team should support the pmap and vmap transforms out of the box. There might be some tips scattered across the JAX docs that we don't know about, so it would be very helpful if we could add some more info about accelerator devices in the docs. If you know any, please let me know.

@PaoloRanzi81

Thanks @fehiepsi! I really hope you can solve the Out-Of-Memory (OOM) error with the RAM, which is blocking lots of people from using multi-GPUs or multi-TPUs with Numpyro!

Below are what I think the 3 priorities are.

First priority: Updating README.md file in the Github Numpyro’s repository

I am very surprised that nobody has ever tested Numpyro with multi-GPUs or multi-TPUs. In the README.md GPUs and TPUs seem to be a big selling point for Numpyro. Unfortunately, the README.md does not even mention those nasty OOM errors...

Therefore, I think the README.md should be more truthful and transparent. It should clearly state that Numpyro does NOT work with GPUs and TPUs for professional Bayesian models (i.e. complex models with 450000 MCMC samples). Indeed, the OOM error limits the use of GPUs and TPUs to very trivial models with fewer than 50000 MCMC samples.

Unless Numpyro gets rigorously tested with multi-GPUs or multi-TPUs and the OOM error is solved, we cannot hide the fact that the OOM error is completely blocking Numpyro for professional analyses. I understand it is painful to admit that, but it is the reality we are facing. At the very least, the README.md should warn researchers and practitioners NOT to waste thousands of their working hours trying Numpyro on GPUs or TPUs for professional use cases.

Second priority: Solving the OOM error

The massive consumption of RAM is a major roadblock to using Numpyro for professional cases. The OOM error must be solved. Unfortunately, it seems that the OOM error is a very hard problem to solve. I am not sure if there is a connection, but JAX's issues about the OOM error are already very old and no definitive solution has been found yet. Indeed, there is a track record of people mentioning the monotonic increase of RAM in JAX almost 2 years ago. Even this very issue in Numpyro has been open since February 2020 and still no solution is available.

For my needs I cannot bring Numpyro to production, since the OOM error is limiting me to 50000 MCMC samples. I would prefer to have 10 times more MCMC samples than that! That's the reason why I am using the “old” CPU-based PyMC3 implementation in production. I am sure PyMC3 could be considered slow and expensive because of the use of CPUs and the overheads. However, PyMC3 is reliable and it can handle 450000 MCMC samples easily.

Third priority: clear tutorial about Numpyro with multi-GPUs or multi-TPUs

Once the OOM error is solved, I agree users would like to have a complete tutorial on how to use chain_method=“parallel” (and the other options) with more than one accelerator. Specifically, one tutorial about 2 GPUs, one about 1 TPU, and one about 2 TPUs would be nice. It would be great to add to each tutorial a speed benchmark against the other probabilistic programming languages like PyMC3 and TensorFlow Probability, using the very same data-set (e.g. the radon data-set from Gelman 2006).

I believe that Numpyro, PyMC3 and TensorFlow Probability are the 3 main contenders in the Python ecosystem for Bayesian modeling. It would be nice to compare them in terms of speed, but not in terms of results or learning curve. In terms of results these 3 libraries should be -hopefully- almost identical. In terms of learning curve, it is difficult to test them against each other (because of its qualitative components).
Benchmarking these 3 libraries in terms of speed, instead, would be a great added value to the tutorial. This would help inform best practices for finding the most cost-effective implementation in production (“Which are the most cost-effective: CPUs, GPUs, or TPUs?”). Indeed, benchmarking speed is particularly useful in production in case Bayesian models need to run multiple times per day.

@fehiepsi (Member)

Thanks for your feedback, @PaoloRanzi81!

FYI, we did test multiple GPUs and TPUs when we added the chain methods and mentioned those supports, e.g. here and here. But we haven't tested them rigorously.

I agree that we should have a tutorial on multi-GPUs or multi-TPUs. We pinned an issue here to call for NumPyro users to contribute and share experience. We also made an issue here for a test suite for performance testing. Are you interested in contributing to one of those items? That would be very valuable for the community. :)

Regarding the OOM, I'm not sure the first action should be to add new claims to the README without investigating the issue. I think that is also a main theme of open source development. For example, if arange(int(1e10)) raises OOM on a GPU, it doesn't mean that all array libraries have to claim in their README that using a GPU is a bad idea in those cases. So could you make a separate issue to point out the problem that you faced, together with some reproducible code or hints that help us diagnose it? The issue in this thread is about sequential drawing, that is, you run MCMC to get some samples, then run it again (possibly with new data) to get more samples. And it was resolved within a few days of the memory leak being raised.

It looks like the claim here is that with num_samples ~ 100000, num_chains ~ 8, and latent dimension ~ 100, we will get OOM on accelerator devices (that means the output will be an 8 x 100000 x 100 array). I can make a first attempt to verify that on TPUs and on my single GPU. We can follow up on diagnosing your issue in a separate thread. What do you think?

@fehiepsi (Member)

Updated: I got no error for num_samples=100,000, num_chains=8, latent dimension=100 on TPUs (see this gist), and it is pretty fast (1min19s). On my single GPU (using the same script with numpyro.set_platform("gpu")), it is very slow (which is expected, because the GPU is not designed for this job) to draw 100,000 samples per chain, so I reduced num_samples to 10,000 per chain (so the total is 80,000 samples). FYI, here is the profile:

(8, 10000, 100)
peak memory: 5462.81 MiB, increment: 5415.39 MiB
CPU times: user 9min 23s, sys: 6.47 s, total: 9min 29s
Wall time: 9min 25s

If you use #893, you will see a progress bar that displays the progress. I also optimized the memory requirement a bit there, which hopefully relaxes your OOM disappointment a bit. Let's try to benchmark more once you isolate the root of the problem. Hopefully, we can diagnose the issue together and share the experience with the community.

@PaoloRanzi81

Thanks @fehiepsi for the nice improvements!

I am slowly reading the documentation you have provided. Here are my quick comments:

  1. On my side it is no problem to provide the model + data-set to you. I have to ask my employer, though. What's good about my model is that it works great with PyMC3 and 450000 total MCMC samples.
    If it is not possible to provide the actual data-set, I will try to make a small reproducible example of the OOM error I am facing with 1 GPU. Lastly, I wanted to confirm that using CPUs Numpyro works perfectly (no matter how big the total number of MCMC samples is!!!) and the RAM does not increase monotonically, but stays below 4-5 GB all the time.

  2. Unfortunately, I have very little free time and cannot provide thorough tutorials. But I will do my best to test the recent Numpyro improvements against my model. It will take a week or so to give you some feedback. Let's keep in touch.

  3. I have neither 2 GPUs nor 1 TPU at the moment. I am waiting for a reply from my employer about it. The best I can do for now is test the model on 1 GPU and on several vCPUs.

  4. Please note that I do not think I have ever actually used chain_method=“parallel”, because Numpyro reverts automatically to chain_method=“sequential” when it finds only 1 GPU.
    I summarize here the settings which triggered OOM errors when I was using 1 GPU (I am using Numpyro's terminology now):

  • chain_method=“sequential” + num_warmup=5375 + num_samples=1000 + num_chains=2 ;
  • chain_method=“vectorized” + num_warmup>=12000 + num_samples=1000 + num_chains=8 ;
  5. Very good that you have tested the TPU! I will try the same code for building a fake model. Perhaps the TPU is not affected by the OOM error. Fingers crossed!

  6. Running the GPU for only 9 min might not be enough to spot the OOM. Try to run the GPU example for longer (hours, not minutes) to see the OOM error. You should see it with the naked eye using the htop command on the Linux command line. E.g. in my case I ran the GPU for 4 Hrs before seeing the 39 GB of RAM blow up.
    Further, please try the TPU example again and set the parameters to a heavier use case (what I actually need in production): chain_method=“parallel” + num_warmup = 55250 + num_samples = 1000 + num_chains = 8.
    Remember I have ~ 120 latent variables and a very complex model (lots of hyper-priors with nested distributions). I am hoping the toy example you are using (its snippet below) is able to account for that.
    numpyro.sample("x", dist.Normal(0, 1).expand([100]))

  7. I am eager to test soon whether the _laxmap did improve things with 1 GPU!

  8. Sorry, but I did not understand your sentence “...that is you run MCMC to get some samples, then run it again (possibly with new data) to get more samples.” My naive understanding was that by using chain_method=“sequential” one single chain is run at a time, but with the VERY same data-set. Moreover, it is still not clear to me why, with chain_method=“sequential”, when a single chain has finished the RAM does not unload but its usage keeps increasing.


fehiepsi commented Jan 27, 2021

Hi @PaoloRanzi81, thanks for your feedback! Looking forward to hearing the result with the master branch for progress_bar=False and progress_bar=True. In my opinion, using progress_bar=True might be the best option for GPU because GPU is slow in sequential inference methods like MCMC and the small improvement of progress_bar=False would be negligible.

why it happened to me that by chain_method=“sequential” when a single chain has finished, the RAM does not unload but its usage keeps increasing

Probably (my best guess) it is an issue that we just fixed in the master branch. I don't know the relationship between RAM and GPU memory well enough to give a clear explanation... About "sequential": the request from other users was to call .run(...), then call .run(...) again sequentially. The implementation details are a bit different from chain_method="sequential".

Running GPU for only 9 min could be not enough to spot the OOM

Probably. I did look at the memory usage of lax.map and saw that RAM increases after a period of time, but not badly enough to OOM - probably because the model I used is simple... Let's see what you get with the last fix.

please try again the TPU example and set the parameters to a more heavy use case (that what I actually need in production): chain_method=“parallel” + num_warmup= 55250 + num_samples=1000 + num_chains=8 .

Hmm, this is even smaller than the numbers that I tested. You can also test it yourself with the same code, because a TPU is free to use in Colab. In the gist that I posted in the last comment, the number of samples is 8 x 100,000, i.e. 800,000 samples. I guess I might be missing something, or there is a missing 0 here.

Let's hope that using the master branch with progress_bar=True (the default value) will resolve the memory issue for you. In my case, with progress_bar=True, RAM stays constant at 2.5% (of the total 32GB RAM) to draw 2 x 100,000 = 200,000 samples. With progress_bar=False, RAM stays constant at 4% to draw 8 x 10,000 = 80,000 samples. You can raise various other hypotheses around those numbers, but please test them first and try to isolate the problem. I believe you have a similar system to mine (one 8GB GPU, free TPU on Colab, 32GB RAM), so you can try to make a simple model bigger, more complicated, at the production level, ... to see if the issue persists.

@PaoloRanzi81

@fehiepsi
I can feel the pain of open-source software.

  1. So far I have always used progress_bar=False. From reading the documentation, I believed that was the fastest option. Anyway, I will try both options now.

  2. My naive understanding of chain_method="sequential" with 1 GPU is that each .run(…) corresponds to running a single MCMC chain. When Numpyro finishes a single MCMC chain it starts the next one. At the end Numpyro merges all those partial results into one single data frame.
    Therefore I was thinking that Numpyro, when using chain_method="sequential", should cleverly unload the MCMC trace of each chain from RAM and store it on hard disk as soon as a single chain has completed. This would have helped free up RAM once each chain ends.
    Unfortunately, the monotonic increase of RAM hints that this is not the case: at each .run(…) cycle no RAM gets freed up.

  3. Ok. It looks like you did notice a small increase of RAM with lax.map and 1 GPU, as I did.
    You should definitely check whether you see the same RAM increase using _laxmap. I believe that when you make your model more complex (say 200 latent variables or more), increase the total MCMC samples a lot, and run for 3 hours, you should see with the naked eye that the RAM usage goes through the roof.
    It is not a small increase. It's a huge one. It is not a 4% increase which stays constant; at the end of the 3 Hrs it is an 80% increase. And from the beginning it increases monotonically and does not stay constant as you mentioned.
    As confirmation, when using 1 GPU I can clearly see the steep “linear regression” in RAM usage using the monitoring tools provided by both GCP Compute Engine and AWS EC2. At the end of the 3 Hrs it almost reaches 39 GB of RAM consumption. When the 39 GB RAM threshold is crossed, the machine shuts down automatically.

  4. My system at the moment is 1 GPU Nvidia Tesla 16 GB.

  5. I probably caused some confusion with PyMC3's terminology. What I was suggesting is to try a TPU with chain_method=“parallel” + num_warmup = 55250 + num_samples = 1000 + num_chains = 8.
    As a terminology conversion, in PyMC3 num_warmup = tune, num_samples = draws, num_chains = chains.
    The formula used internally by PyMC3 is: (draws + tune) * chains = total MCMC samples.
    In my production case: (55250 + 1000) * 8 = 450000 MCMC samples.

Unfortunately, with Numpyro and chain_method=“vectorized” + 1 GPU I could not do better than (5375 + 1000) * 8 = 51000 MCMC samples. Any more than that and a computer crash was guaranteed because of the nasty OOM error. Even worse with chain_method=“sequential”: with (5375 + 1000) * 2 = 12750 MCMC samples the RAM blows up (which means it crosses the 39 GB threshold).


fehiepsi commented Jan 28, 2021

@PaoloRanzi81 I created #899 to track down the issue, because the previous issues in this topic have been resolved and I don't want to send unnecessary notifications to other users. Please follow up in the new issue that I created, which might catch more eyes from other users.

Some of my thoughts:

  • re vectorized: It might be expected; I am not sure what we can do to resolve the memory issue there. In addition, it is an experimental feature, so please use it with caution regarding performance. Theoretically, that method will be slow. It might be fast in situations such as drawing thousands of chains on simple models. I don't recommend using it for your problem. If your model is complicated, I don't recommend using a GPU to draw a lot of samples at all, because GPUs are designed for parallel computation, not sequential computation.
  • re TPU: you can try it yourself, I believe. You can use the code and link that I commented above, which shows how to use a TPU on Colab. In the gist I posted above, I used the parallel method to draw 8 chains with 100,000 samples per chain. That number is ~2x larger than what you requested.
  • re increasing the latent dimension from 100 -> 200, increasing the total number of samples, ...: Could you test it on the simple model in that gist and let me know? I already ran 2 chains with 100,000 samples per chain with progress_bar=True in my last comment. I also increased the latent dimension to 1000 and still saw constant memory for a while, so I stopped in the middle of the process. Probably it is a device issue.
  • re freeing RAM resources: I don't know. I have little knowledge of what the expected behavior is when running things on a GPU. Your explanation makes sense, but I am not sure a JAX program will behave that way.
  • re storing samples on hard disk: I think you can do it yourself if you want (see the sketch after this list). I'm not sure it should be supported directly in NumPyro; it might raise more questions and issues than benefits.
  • re progress_bar=False: it will speed up your program in many cases, but not always. Please make a PR or FR to revise the docs if they confused you. As I said in my last comment, the speed gain is usually negligible on GPU.
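
For concreteness, here is a minimal user-side sketch of that disk off-loading idea (not a NumPyro feature; it assumes the toy test1 model used earlier in this thread):

import numpy as onp
from jax import random, device_get
from numpyro.infer import MCMC, NUTS

# Run one chain at a time, pull the draws back to host memory, and save each
# chain to its own .npz file so nothing accumulates on the GPU or in RAM.
mcmc = MCMC(NUTS(test1), num_warmup=500, num_samples=1000, num_chains=1)
for chain in range(8):
    mcmc.run(random.PRNGKey(chain))
    onp.savez("chain_{}.npz".format(chain), **device_get(mcmc.get_samples()))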
