Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds basic infrastructure with prototype code for PM pipeline #9

Merged
merged 10 commits into from
May 31, 2022

Conversation

EiffL
Copy link
Member

@EiffL EiffL commented May 18, 2022

This PR adds some basic infra, like adds a setup.py. It also adds some code based on jaxpm to generate density planes from an nbody simulation, which can be used in a forward model.

A demo notebook is provided here
It generates DC2-like maps:
image

This provides a few first elements towards solving #8

@EiffL EiffL requested a review from Supranta May 18, 2022 14:33
bpcosmo/pm.py Outdated Show resolved Hide resolved
bpcosmo/pm.py Outdated Show resolved Hide resolved
Co-authored-by: Alexandre Boucaud <[email protected]>
bpcosmo/pm.py Outdated Show resolved Hide resolved
bpcosmo/pm.py Outdated

# Extract the lensplanes
density_planes = []
for i in range(len(a_center)):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for i in range(len(a_center)):
for i in range(n_lens):

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

line 123: is it the same 0.01 as in line 84? If so, maybe worth defining a new variable?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indeed, done!

Copy link
Contributor

@aboucaud aboucaud left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking good, but please homogenise the code indentation in the file.

bpcosmo/pm.py Outdated
dz = density_plane_width
plane = density_plane(res[0][::-1][i],
nc,
(i+0.5)*density_plane_width/box_size[-1]*nc[-1],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the "+0.5"? Some corner assignment scheme?

density_plane_smoothing: Gaussian scale of plane smoothing
box_size: [sx,sy,sz] size in Mpc/h of the simulation volume
nc: number of particles/voxels in the PM scheme
neural_spline_params: optional parameters for neural correction of PM scheme
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need a reference cited for the neural spline?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes indeed, unfortunately we don't have one yet ^^' this is some very recent trick that @dlanzieri is still working on. We might have an arxiv ref in few weeks

bpcosmo/pm.py Outdated
neural_spline_params: optional parameters for neural correction of PM scheme
Returns:
list of [r, a, plane], slices through the lightcone along with their
comoving distance and scale factors.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Being pedantic here, but maybe it's worth clarifying: "list of [r, a, plane], slices through the lightcone along with their comoving distance (r) and scale factors (a), with 'plane' being the index of the density slice.", or something along those lines.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

plane would range from 0 to nc[0]?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep thank you, added more clarification

bpcosmo/pm.py Outdated
a_center = jc.background.a_of_chi(cosmology, r_center)

# Create a small function to generate the matter power spectrum
k = jnp.logspace(-4, 1, 128)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why 128 bins?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've upped to to 256 to be safe. These are points used to interpolate the matter power spectrum, needed when building initial conditions. It shouldn't be very sensitive to this parameter choice.

bpcosmo/pm.py Outdated

# Initial displacement
cosmology._workspace = {} # FIX ME: this a temporary fix
dx, p, f = lpt(cosmology, initial_conditions, particles, 0.01)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where does 0.01 come from?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yes, 0.01 is the the initial scale factor at which we initialize the simulation.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added explicitly the keyword

bpcosmo/pm.py Outdated
@jax.jit
def neural_nbody_ode(state, a, cosmo, params):
"""
state is a tuple (position, velocities)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

velocities in km/s?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hummm yeah good question... I think what I added there is correct

bpcosmo/pm.py Outdated
# Computes gravitational potential
pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec, r_split=0)

# Apply a correction filter
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need this correction? what filter is this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the magic sauce from this spline correction function that @dlanzieri is building. I hope we'll soon have a paper to describe it.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it corrects power on small scales, to effectively increase the resolution of the simulation

@EiffL EiffL requested a review from elts6570 May 31, 2022 12:42
@EiffL
Copy link
Member Author

EiffL commented May 31, 2022

I've tried to address all your comments @elts6570, let me know what you think and if you approve this PR :-)

@elts6570
Copy link
Collaborator

@EiffL looks great! looking forward to the magic splines paper! :)

@EiffL EiffL merged commit 69e82a9 into main May 31, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants