Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds basic infrastructure with prototype code for PM pipeline #9

Merged
merged 10 commits into from
May 31, 2022
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added bpcosmo/__init__.py
Empty file.
144 changes: 144 additions & 0 deletions bpcosmo/pm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# This module contains utility functions to run a forward simulation
import jax
import jax.numpy as jnp

import jax_cosmo as jc

import numpyro
import numpyro.distributions as dist

from jax.experimental.ode import odeint
from jaxpm.pm import lpt, make_ode_fn
from jaxpm.kernels import fftk
from jaxpm.lensing import density_plane
import haiku as hk

from jaxpm.painting import cic_paint, cic_read
from jaxpm.kernels import gradient_kernel, laplace_kernel, longrange_kernel
from jaxpm.nn import NeuralSplineFourierFilter

__all__ = [
'get_density_planes',
]

model = hk.without_apply_rng(hk.transform(lambda x,a : NeuralSplineFourierFilter(n_knots=16, latent_size=32)(x,a)))

def linear_field(mesh_shape, box_size, pk):
"""
Generate initial conditions.
"""
kvec = fftk(mesh_shape)
kmesh = sum((kk / box_size[i] * mesh_shape[i])**2 for i, kk in enumerate(kvec))**0.5
pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (box_size[0] * box_size[1] * box_size[2])

field = numpyro.sample('initial_conditions',
dist.Normal(jnp.zeros(mesh_shape),
jnp.ones(mesh_shape)))

field = jnp.fft.rfftn(field) * pkmesh**0.5
field = jnp.fft.irfftn(field)
return field

def get_density_planes(cosmology,
density_plane_width=100., # In Mpc/h
density_plane_npix=256, # Number of pixels
density_plane_smoothing=3., # In Mpc/h
box_size=[400.,400.,4000.], # In Mpc/h
nc=[64,64,640],
EiffL marked this conversation as resolved.
Show resolved Hide resolved
neural_spline_params=None
):
"""Function that returns tomographic density planes
for a given cosmology from a lightcone.

Args:
cosmology: jax-cosmo object
density_plane_width: width of the output density slices
density_plane_npix: size of the output density slices
density_plane_smoothing: Gaussian scale of plane smoothing
box_size: [sx,sy,sz] size in Mpc/h of the simulation volume
nc: number of particles/voxels in the PM scheme
neural_spline_params: optional parameters for neural correction of PM scheme
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need a reference cited for the neural spline?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes indeed, unfortunately we don't have one yet ^^' this is some very recent trick that @dlanzieri is still working on. We might have an arxiv ref in few weeks

Returns:
list of [r, a, plane], slices through the lightcone along with their
comoving distance and scale factors.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Being pedantic here, but maybe it's worth clarifying: "list of [r, a, plane], slices through the lightcone along with their comoving distance (r) and scale factors (a), with 'plane' being the index of the density slice.", or something along those lines.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

plane would range from 0 to nc[0]?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep thank you, added more clarification

"""
# Planning out the scale factor stepping to extract desired lensplanes
n_lens = int(box_size[-1] // density_plane_width)
r = jnp.linspace(0., box_size[-1], n_lens+1)
r_center = 0.5*(r[1:] + r[:-1])
a_center = jc.background.a_of_chi(cosmology, r_center)

# Create a small function to generate the matter power spectrum
k = jnp.logspace(-4, 1, 128)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why 128 bins?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've upped to to 256 to be safe. These are points used to interpolate the matter power spectrum, needed when building initial conditions. It shouldn't be very sensitive to this parameter choice.

pk = jc.power.linear_matter_power(cosmology, k)
pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape([-1]), k, pk).reshape(x.shape)

# Create initial conditions
initial_conditions = linear_field(nc, box_size, pk_fn)

# Create particles
particles = jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in nc]),axis=-1).reshape([-1,3])

# Initial displacement
cosmology._workspace = {} # FIX ME: this a temporary fix
dx, p, f = lpt(cosmology, initial_conditions, particles, 0.01)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where does 0.01 come from?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yes, 0.01 is the the initial scale factor at which we initialize the simulation.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added explicitly the keyword


@jax.jit
def neural_nbody_ode(state, a, cosmo, params):
"""
state is a tuple (position, velocities)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

velocities in km/s?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hummm yeah good question... I think what I added there is correct

"""
pos, vel = state

kvec = fftk(nc)

delta = cic_paint(jnp.zeros(nc), pos)

delta_k = jnp.fft.rfftn(delta)

# Computes gravitational potential
pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec, r_split=0)

# Apply a correction filter
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need this correction? what filter is this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the magic sauce from this spline correction function that @dlanzieri is building. I hope we'll soon have a paper to describe it.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it corrects power on small scales, to effectively increase the resolution of the simulation

if params is not None:
kk = jnp.sqrt(sum((ki/jnp.pi)**2 for ki in kvec))
pot_k = pot_k *(1. + model.apply(params, kk, jnp.atleast_1d(a)))

# Computes gravitational forces
forces = jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k), pos)
for i in range(3)],axis=-1)

forces = forces * 1.5 * cosmo.Omega_m

# Computes the update of position (drift)
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel

# Computes the update of velocity (kick)
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
EiffL marked this conversation as resolved.
Show resolved Hide resolved

return dpos, dvel

# Evolve the simulation forward
res = odeint(neural_nbody_ode, [particles+dx, p],
jnp.concatenate([jnp.atleast_1d(0.01), a_center[::-1]]),
cosmology, neural_spline_params, rtol=1e-5, atol=1e-5)

# Extract the lensplanes
density_planes = []
for i in range(len(a_center)):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for i in range(len(a_center)):
for i in range(n_lens):

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

line 123: is it the same 0.01 as in line 84? If so, maybe worth defining a new variable?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indeed, done!

dx = box_size[0]/density_plane_npix
dz = density_plane_width
plane = density_plane(res[0][::-1][i],
nc,
(i+0.5)*density_plane_width/box_size[-1]*nc[-1],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the "+0.5"? Some corner assignment scheme?

width=density_plane_width/box_size[-1]*nc[-1],
plane_resolution=density_plane_npix,
smoothing_sigma=density_plane_smoothing/dx
)
density_planes.append({'r':r_center[i],
'a':a_center[i],
'plane': plane,
'dx': dx,
'dz': dz})

return density_planes
429 changes: 429 additions & 0 deletions notebooks/forward_model/LensingForwardModel.ipynb

Large diffs are not rendered by default.

File renamed without changes.
32 changes: 32 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from setuptools import find_packages
from setuptools import setup
from io import open

# read the contents of the README file
with open('README.md', encoding="utf-8") as f:
long_description = f.read()

setup(
name='BPCosmo',
description='Cosmology-level Bayesian Pipeline forward model',
long_description=long_description,
long_description_content_type='text/markdown',
author='LSST DESC',
url='https://github.com/LSSTDESC/bayesian-pipelines-cosmology',
license='MIT',
packages=find_packages(),
install_requires=['numpyro', 'jax', 'lenstools', 'dm-haiku', 'jaxpm', 'jax-cosmo'],
dependency_links=[
'https://github.com/DifferentiableUniverseInitiative/JaxPM/tarball/master#egg=jaxpm-0.0.1'
],
classifiers=[
'Development Status :: 2 - Pre-Alpha',
'Intended Audience :: Developers',
'Intended Audience :: Science/Research',
'License :: OSI Approved :: MIT License',
'Topic :: Scientific/Engineering :: Astronomy'
],
keywords='cosmology',
use_scm_version=True,
setup_requires=['setuptools_scm']
)