
Matter power spectrum as function of cosmology (most likely emulator, or E.H.) #2

Open
Tracked by #1
EiffL opened this issue Apr 13, 2022 · 16 comments

EiffL commented Apr 13, 2022

This issue is to discuss the particular choice for the matter power spectrum that will be an input to all forward pipelines.

Here is a question to get us started:
@Supranta @nataliaporqueres : What are your options in terms of differentiable matter power spectra?

With @jecampagne and @dlanzieri, we are still just using Eisenstein & Hu in jax-cosmo and flowpm, respectively, but we can adapt to whatever other solution you are currently using.
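
For reference, a minimal sketch of the Eisenstein & Hu route in jax-cosmo (the function names below follow the jax_cosmo API we are using; treat them as indicative and check your installed release):

import jax
import jax.numpy as jnp
import jax_cosmo as jc

cosmo = jc.Planck15()
k = jnp.logspace(-3, 1, 256)  # h/Mpc

# Linear matter power spectrum (Eisenstein & Hu transfer function by default)
pk_lin = jc.power.linear_matter_power(cosmo, k, a=1.0)

# Everything stays differentiable, e.g. the derivative w.r.t. Omega_c
dpk_dOc = jax.jacfwd(
    lambda oc: jc.power.linear_matter_power(jc.Planck15(Omega_c=oc), k, a=1.0)
)(cosmo.Omega_c)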

EiffL mentioned this issue Apr 13, 2022
EiffL self-assigned this Apr 13, 2022
@nataliaporqueres

In BORG, we have Eisenstein-Hu and CLASS.

EiffL commented Apr 13, 2022

OK, great, thanks @nataliaporqueres :-) But I guess you can't take derivatives through CLASS. Do you have a CLASS emulator, or do you use some trick to get the derivatives by finite differences from CLASS?

@nataliaporqueres

We only need the derivative with respect to the initial conditions delta^{IC}, because the HMC is only used to sample the density field. We use a slice sampler for the cosmological parameters, which doesn't require derivatives.
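
For readers less familiar with the scheme: a generic step-out slice sampler for a single parameter (Neal 2003) looks roughly like the sketch below. This is not BORG's implementation, just an illustration of why no derivatives are needed.

import numpy as np

def slice_sample_1d(logp, x0, w=0.5, rng=np.random.default_rng(0)):
    # One step-out slice-sampling update of a scalar parameter.
    # logp: log posterior of the parameter (no gradient required); w: bracket width.
    logy = logp(x0) + np.log(rng.uniform())   # draw the slice level under the density
    left = x0 - w * rng.uniform()             # random initial bracket around x0
    right = left + w
    while logp(left) > logy:                  # step out until the bracket exits the slice
        left -= w
    while logp(right) > logy:
        right += w
    while True:                               # shrink the bracket until a point is accepted
        x1 = rng.uniform(left, right)
        if logp(x1) > logy:
            return x1
        if x1 < x0:
            left = x1
        else:
            right = x1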

EiffL commented Apr 14, 2022

Ah! Interesting! I hadn't realized that, sorry. Very interesting :-)

Supranta commented Apr 19, 2022

I have a PCE emulator as well as a GP emulator for the power spectrum.
I am currently adapting my GP emulator to work with tinygp, which is based on JAX, so I can add it here.
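
For context, a minimal tinygp sketch of the kind of GP building block this would plug into (the data are toy placeholders, and attribute names may differ slightly between tinygp versions):

import jax
import jax.numpy as jnp
from tinygp import GaussianProcess, kernels

# Toy stand-ins for (cosmology -> log Pk) training data at one (k, z) bin
X_train = jax.random.uniform(jax.random.PRNGKey(0), (64, 5))  # 5 cosmological parameters
y_train = jnp.sin(X_train.sum(axis=1))                        # placeholder targets

kernel = 0.5 * kernels.ExpSquared(scale=0.3)
gp = GaussianProcess(kernel, X_train, diag=1e-5)

# Marginal likelihood, usable as a loss to fit the hyperparameters with JAX
loglike = gp.log_probability(y_train)

# Differentiable predictions at new cosmologies
X_test = jax.random.uniform(jax.random.PRNGKey(1), (4, 5))
_, cond_gp = gp.condition(y_train, X_test)
y_pred = cond_gp.loc   # predictive mean (attribute name as in the tinygp docs)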

Supranta self-assigned this Apr 19, 2022
@jecampagne

I have set up this Jax-CLASS emulator based on the original work by the authors of https://arxiv.org/pdf/2105.02256.pdf.

EiffL commented May 31, 2022

This is great @jecampagne :-D The only thing is that your repo is private, so people will probably not be able to see the content. Would you be OK with making it public?

jecampagne commented Jun 19, 2022

Proposal for a new set of cosmologies:
I propose the following ranges for the primary parameters (1000 Latin hypercube points):
‘Omega_cdm’: [0.1, 0.5]
‘Omega_b’: [0.04, 0.06]
‘sigma8’: [0.7, 0.9]
‘n_s’: [0.87, 1.06]
‘h’: [0.55, 0.8]
Below are the distributions of these variables (first row), with the derived parameters (second row); the red vertical line marks the fiducial values.
[figure: distributions of the primary and derived parameters]

And I will restrict the redshift range to [0, 3.5]; for k the range is [5e-4, 50] h/Mpc, but the upper bound will be closely linked to the dP/P accuracy (i.e. the relative difference w.r.t. the CLASS computation).
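
For concreteness, this is one way to draw the 1000 Latin hypercube points over these ranges, sketched with scipy's QMC module (the actual Jemu scripts may do it differently):

import numpy as np
from scipy.stats import qmc

ranges = {
    "Omega_cdm": (0.1, 0.5),
    "Omega_b": (0.04, 0.06),
    "sigma8": (0.7, 0.9),
    "n_s": (0.87, 1.06),
    "h": (0.55, 0.8),
}

sampler = qmc.LatinHypercube(d=len(ranges), seed=42)
unit = sampler.random(n=1000)                   # 1000 points in the unit hypercube
lows = np.array([lo for lo, _ in ranges.values()])
highs = np.array([hi for _, hi in ranges.values()])
cosmologies = qmc.scale(unit, lows, highs)      # shape (1000, 5), one row per cosmology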

jecampagne commented Jun 22, 2022

Here is a result with the new training sets and the retraining of the models with the new scheme of building blocks, taking care of the CLASS k_max_h_by_Mpc parameter, which influences the stability of the Pk computations.
The figure below uses the Planck15 jax-cosmo parameters, which is not a training cosmology for the emulator.
[figure: emulator result for the Planck15 test cosmology]

Comparison with jax-cosmo for the Jacobian and the vmap; note that at the scale of the figures one cannot appreciate differences at the level of a few %.
[figures: Jacobian and vmap comparisons with jax-cosmo]

jecampagne commented Jun 22, 2022

In fact, the 1D and 2D interpolations on the (k, z) grid are performed (and constrained) over a regular grid (log scale for k, linear scale for the redshift z); ideally I would sample differently to get better accuracy in the interval k = [0.1, 1] h/Mpc. But the bottleneck is certainly the cosmological parameter sampling at the end of the day.
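
To illustrate what the regular-grid constraint buys (index arithmetic instead of a search), a generic bilinear interpolation of log P on a regular (log k, z) grid could look like the sketch below; this is not the emulator's actual code:

import jax.numpy as jnp

def interp2d_regular(logk_q, z_q, logk_grid, z_grid, logpk_table):
    # Bilinear interpolation on a regular grid; logpk_table has shape
    # (len(logk_grid), len(z_grid)).
    ik = (logk_q - logk_grid[0]) / (logk_grid[1] - logk_grid[0])
    iz = (z_q - z_grid[0]) / (z_grid[1] - z_grid[0])
    ik0 = jnp.clip(jnp.floor(ik).astype(int), 0, len(logk_grid) - 2)
    iz0 = jnp.clip(jnp.floor(iz).astype(int), 0, len(z_grid) - 2)
    tk, tz = ik - ik0, iz - iz0
    # Blend the four surrounding grid values
    p00 = logpk_table[ik0, iz0]
    p10 = logpk_table[ik0 + 1, iz0]
    p01 = logpk_table[ik0, iz0 + 1]
    p11 = logpk_table[ik0 + 1, iz0 + 1]
    return ((1 - tk) * (1 - tz) * p00 + tk * (1 - tz) * p10
            + (1 - tk) * tz * p01 + tk * tz * p11)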

jecampagne commented Jun 24, 2022

I would like to share something quite mundane (sorry, this is not great ML/Bayesian theory): it concerns the time spent loading the parameters of the emulators.

In the current version, I need to load 2x120 (PkLin and PkNLin at z=0) + 2x120x20 (the functions P(k,z)/P(k,z=0), both linear and non-linear) ~ 5200 GPs. The total amount of data is 288 MB (not a big deal, is it?).

Now, depending on where the parameter files are located and on the link bandwidth, it can take (here at the computing centre in Lyon, with the data in the system cache):

Growth end-start (sec) 10.849421180784702
Pk Lin end-start (sec) 0.5452490821480751
Pk Non Lin end-start (sec) 0.541755635291338
Q func end-start (sec) 11.04175541549921

but it can take 10 times longer if the data are not in the system cache. And in Google Colab, where the data are on "MyDrive", it took 1 h to load in a notebook! But two days later it took 30 s...

Can someone help speed up this tedious process?
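
One generic way to hide the latency of many small reads is a thread pool, as in the sketch below (the file layout, pattern and function name are hypothetical, not Jemu's actual structure):

from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
import numpy as np

def load_gp_params(root, pattern="*.npy", max_workers=16):
    # Load every GP parameter file under `root` with overlapping I/O.
    # Threads help because the work is I/O-bound; on a cold cache or a slow
    # network file system this can cut the wall-clock load time substantially.
    files = sorted(Path(root).rglob(pattern))
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        arrays = list(pool.map(np.load, files))
    return {f.stem: a for f, a in zip(files, arrays)}

Packing the thousands of small files into a few consolidated archives (e.g. one .npz per building block) would also cut the number of round trips, which is probably what dominates in the Google Drive case.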

Note that, after performing the computation of the linear & non-linear Pk,

pk_linear_interp = emu.linear_pk(cosmo_jax, k_star, z_star)
pk_nonlin_interp = emu.nonlinear_pk(cosmo_jax, k_star, z_star)

with 1200 k-values (k_star) for 4 different redshifts (z_star), the XLA compilation times can vary and last up to 10 min. Afterwards, the execution takes about 15-30 ms.

jecampagne commented Jun 26, 2022

I have refactored the code (not yet public), which allows a jitted version of

k_star = jnp.geomspace(5e-4, 50, 1200, endpoint=True)  # h/Mpc
z_star = jnp.array([0., 1., 2., 3.])
pk_linear_interp = emu.linear_pk(cosmo_jax, k_star, z_star)

The XLA compilation takes 400 s, while subsequent executions take 25 ms. The compilation may perhaps be sped up by using more JAX idioms.
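
To make the compile/run split explicit, the timing pattern is roughly the one below (reusing emu and cosmo_jax from the snippet above, and assuming emu.linear_pk returns a JAX array):

import time
import jax
import jax.numpy as jnp

linear_pk_jit = jax.jit(emu.linear_pk)   # hypothetical: jit the emulator call once and reuse it

k_star = jnp.geomspace(5e-4, 50, 1200, endpoint=True)  # h/Mpc
z_star = jnp.array([0., 1., 2., 3.])

t0 = time.perf_counter()
linear_pk_jit(cosmo_jax, k_star, z_star).block_until_ready()  # first call triggers XLA compilation
print("compile + run:", time.perf_counter() - t0, "s")

t0 = time.perf_counter()
linear_pk_jit(cosmo_jax, k_star, z_star).block_until_ready()  # cached executable
print("cached run:", time.perf_counter() - t0, "s")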

jecampagne commented Jun 30, 2022

I succeeded in speeding up the data loading thanks to parallelized code; it takes 25 s at the CC and on Google Colab. Of course, if the bandwidth is greater, this overhead can be reduced further. Concerning JAX/JIT, I'm doing an exhaustive code review and XLA tweaking, but the work is not yet satisfactory. Without JIT the emulator is already faster than jax-cosmo; if JIT is turned on, the XLA compilation can take several minutes, but then the emulator runs at the millisecond level.

jecampagne commented Jul 4, 2022

With the help of XLA experts, I did a code review to use PyTrees as well as JIT sparingly.
I managed to get an XLA compilation of ~8 s and an execution at the level of 1.5 s for 1200x4 (k*, z*) evaluations, for both the linear and non-linear Pk. The emulator is now faster than the current jax_cosmo library for the same types of computations (e.g. vmap & jacfwd). The Git repo has been updated accordingly.
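
For reference, the pattern in question is to register the emulator state as a PyTree, so that its arrays are seen by jit as traced data rather than opaque Python objects (avoiding one recompilation per new instance). A generic sketch, with illustrative names that are not Jemu's:

import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
class EmuState:
    # Illustrative container for emulator arrays, visible to jit as a PyTree.
    def __init__(self, weights, k_grid, z_grid):
        self.weights = weights
        self.k_grid = k_grid
        self.z_grid = z_grid

    def tree_flatten(self):
        return (self.weights, self.k_grid, self.z_grid), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

@jax.jit
def predict(state, k):
    # Placeholder computation standing in for the real GP/interpolation step
    return jnp.interp(jnp.log(k), jnp.log(state.k_grid), state.weights)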

jecampagne commented Jul 5, 2022

I finally succeeded in boosting the prediction speed at a moderate cost in XLA compilation time. With the newly committed version of Jemu this yields:

  • 20 s to load the GP parameters
  • 50 s for XLA compilation
  • 4-5 ms to predict the (linear or non-linear) Pk for 1200 k values and several redshifts at once.

Of course, if a new XLA compilation is needed, one waits another 50 s, but after that the JIT power is there.

@jecampagne

I wrote some scripts to

  • build a new cosmological parameter dataset,
  • run CLASS to produce the linear and non-linear Pk, plus the two other functions (the k-dependent growth, and the analogous function for the non-linear Pk); a rough sketch of this step is given below,
  • and finally train the different building blocks.
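
A rough sketch of the CLASS step, using classy (the parameter names follow the CLASS conventions, but this is not the actual Jemu script):

import numpy as np
from classy import Class

def run_class(params, k_grid_h, z_grid):
    # Compute linear and non-linear P(k, z) for one cosmology with CLASS.
    # params: dict with Omega_cdm, Omega_b, sigma8, n_s, h (CLASS input names).
    cosmo = Class()
    cosmo.set({
        "output": "mPk",
        "non linear": "halofit",   # or another non-linear model, depending on the choice
        "P_k_max_h/Mpc": 50.0,
        "z_max_pk": 3.5,
        **params,
    })
    cosmo.compute()
    h = params["h"]
    # classy expects k in 1/Mpc and returns P in Mpc^3; convert to h units
    pk_lin = np.array([[cosmo.pk_lin(k * h, z) * h**3 for z in z_grid] for k in k_grid_h])
    pk_nl = np.array([[cosmo.pk(k * h, z) * h**3 for z in z_grid] for k in k_grid_h])
    cosmo.struct_cleanup()
    cosmo.empty()
    return pk_lin, pk_nl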
