+ Summary
+ diffopt is a Python package that
+ facilitates the optimization of data-parallelized, differentiable
+ models using the Jax
+ (Bradbury
+ et al., 2018) framework. It is composed of three subpackages,
+ multigrad, kdescent, and
+ multiswarm. Leveraging MPI (Message Passing
+ Interface), multigrad efficiently sums and
+ propagates gradients of custom-defined summary statistics across
+ processors and computing nodes. kdescent
+ utilizes mini-batched kernel density estimates to perform stochastic
+ gradient descent to fit a full model distribution to an N-dimensional
+ training dataset. A massively parallelizable implementation of
+ particle swarm optimization (PSO) is provided by
+ multiswarm, enabling global optimization of
+ even high-dimensional, non-convex loss surfaces. Our simple yet
+ flexible design makes these methods applicable to a wide variety of
+ problems that must scale to large amounts of data, through both
+ gradient-based and gradient-free optimization techniques. Visit
+ our
+ documentation
+ page to learn how to use it.
+
+
+ Statement of Need
+ In and beyond the field of cosmology, parameterized models can
+ describe complex systems, provided that the parameters have been tuned
+ adequately to fit the model to observational data. Fitting
+ capabilities can be increased dramatically by gradient-based
+ techniques, particularly in high-dimensional parameter spaces.
+ Existing gradient descent tools in Jax do not inherently support
+ data-parallelism with MPI, creating a speed and memory bottleneck for
+ such computations.
+ multigrad addresses this need by providing
+ an easy-to-use interface for implementing data-parallelized models. It
+ handles the MPI reductions as well as the mathematical complexity of
+ applying the chain rule to compute the gradient of the loss, which is
+ a function of parallelized summary statistics that are in turn
+ functions of the model parameters. At the same time, it is flexible:
+ users define their own functions to compute their summary statistics
+ and loss. As a result, this package allows the optimization routine
+ of nearly any big-data model to scale through parallelization.
+ kdescent and multiswarm
+ each provide powerful fitting tools which are fully compatible with
+ the parallelization framework laid out by
+ multigrad.
+ Past efforts have already been made towards parallelization of Jax
+ (mpi4jax,
+ Häfner
+ & Vicentini, 2021), parallel gradient descent (e.g.,
+ Gray
+ et al., 2019), and parallel PSO
+ (Blank
+ & Deb, 2020;
+ Li
+ & Wada, 2005). Our approach combines many of these features
+ and more into one easy-to-use, documented Python module with all the
+ tools to optimize arbitrarily complex models that the user has
+ implemented in the Jax framework. Additionally, while the fundamental
+ MPI reductions available through mpi4jax are
+ generally sufficient, our multigrad procedure
+ provides significant convenience for problems in which complex summary
+ statistics are computed in parallel before being applied to a
+ differentiable loss function.
+
+
+ Method
+
+ multigrad
+ multigrad allows the user to implement a
+ loss term $L(\vec{y}(\vec{x}))$, which is a function of summary
+ statistics $\vec{y}$, which are in turn functions of the parameters
+ $\vec{x}$, where the summary statistics are summed over multiple
+ MPI-linked processes:
+
+ $$\vec{y} = \sum_i \vec{y}^{(i)}$$
+
+ where $i$ is the index of each process. In this section, we derive
+ the gradient of the loss, $\vec{\nabla} L$, with respect to the
+ parameters, $\vec{x}$, as a sum of terms that each process can
+ compute independently.
+ We will begin from the definition of the multivariate chain rule,
+
+ $$\frac{\partial L}{\partial x_j} = \sum_k \frac{\partial L}{\partial y_k} \frac{\partial y_k}{\partial x_j}$$
+
+ where $\partial y_k = \sum_i \partial y_k^{(i)}$. By pulling out the
+ MPI summation over $i$,
+
+ $$\frac{\partial L}{\partial x_j} = \sum_i \sum_k \frac{\partial L}{\partial y_k} \frac{\partial y_k^{(i)}}{\partial x_j}$$
+
+ and by rewriting this as a vector-matrix multiplication,
+
+ $$\vec{\nabla}_{\vec{x}} L = \sum_i (\vec{\nabla}_{\vec{y}} L)^{T} J^{(i)}$$
+
+ we can clearly identify that each process has to perform a
+ vector-Jacobian product (VJP), where $J^{(i)}$ is the Jacobian matrix
+ such that $J^{(i)}_{kj} = \partial y_k^{(i)} / \partial x_j$.
+ Fortunately, Jax can perform this computation very efficiently via
+ the jax.vjp feature, without explicitly calculating the full
+ Jacobian matrix, saving orders of magnitude in time and memory.
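+ As a rough illustration of this procedure (not the multigrad
+ interface itself), the sketch below uses jax and mpi4py, with
+ hypothetical sumstats_from_params and loss_from_sumstats functions
+ standing in for the user-defined callables:
+
+ ```python
+ # Minimal sketch of the parallel gradient reduction derived above.
+ # Written with jax + mpi4py for illustration; the function names are
+ # hypothetical and not part of the multigrad API.
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ from mpi4py import MPI
+
+ comm = MPI.COMM_WORLD
+
+
+ def sumstats_from_params(params, local_data):
+     """Per-process summary statistics y^(i): a simple weighted sum."""
+     return jnp.stack([jnp.sum(local_data * p) for p in params])
+
+
+ def loss_from_sumstats(sumstats, target):
+     """Differentiable loss L(y) comparing the total sumstats to a target."""
+     return jnp.sum((sumstats - target) ** 2)
+
+
+ def loss_and_grad(params, local_data, target):
+     # y^(i) plus the function needed for the per-process VJP
+     y_i, vjp_fun = jax.vjp(lambda p: sumstats_from_params(p, local_data), params)
+
+     # y = sum_i y^(i) via an MPI reduction
+     y = jnp.asarray(comm.allreduce(np.asarray(y_i), op=MPI.SUM))
+
+     # dL/dy, computed identically on every process
+     loss, dL_dy = jax.value_and_grad(loss_from_sumstats)(y, target)
+
+     # (dL/dy)^T J^(i), summed over processes to obtain the full gradient
+     grad_i = vjp_fun(dL_dy)[0]
+     grad = jnp.asarray(comm.allreduce(np.asarray(grad_i), op=MPI.SUM))
+     return loss, grad
+ ```
+
+ multigrad wraps this pattern so that users only supply their own
+ summary-statistic and loss functions.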
+
+
+ kdescent
+ Mini-batching techniques often compute the loss function with
+ only a small subset of the training data taken into account. In
+ kdescent, the density of the full training
+ dataset is measured around a “mini-batched” sample of kernel
+ centers, which are drawn from points in the training data. With each
+ iteration of stochastic gradient descent, a new sample of (20 by
+ default) kernels is selected at positions $\vec{\mu}_k$ for each
+ kernel $k$. Using the compare_kde_counts method, the “true” and
+ “model” counts are each computed around each kernel using the same
+ equation below, where $\vec{x}_i$ is the $i$th point in the training
+ data or model data, respectively:
+
+ $$N_k = \sum_i \mathcal{N}(\vec{x}_i \,|\, \vec{\mu}_k, \Sigma)$$
+
+ where $\mathcal{N}$ is the multivariate-normal distribution with mean
+ $\vec{\mu}_k$ and covariance matrix $\Sigma$ (where the covariance is
+ calculated using Scott’s rule for kernel density estimation of the
+ training dataset; Scott (1992)).
+ It is then up to the user to define their own loss function
+ comparing the true counts to the model counts. Note that these are
+ extrinsic quantities (as is necessary to be parallelizable through
+ multigrad), which can be reduced to intrinsic quantities for
+ PDF-level comparisons by simply dividing by the total number of
+ training and model points, respectively.
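+ As a rough illustration of this count statistic (not the kdescent
+ API), the following jax sketch computes $N_k$ around a set of kernel
+ centers; the helper names and the explicit Scott’s-rule covariance
+ are assumptions made for the example:
+
+ ```python
+ # Illustrative computation of N_k = sum_i Normal(x_i | mu_k, Sigma).
+ import jax
+ import jax.numpy as jnp
+ from jax.scipy.stats import multivariate_normal
+
+
+ def scotts_covariance(training_data):
+     """Scott's rule: Sigma = n**(-2 / (d + 4)) * Cov(training data)."""
+     n, d = training_data.shape
+     return jnp.cov(training_data, rowvar=False) * n ** (-2.0 / (d + 4))
+
+
+ def kde_counts(data, kernel_centers, cov):
+     """Extrinsic kernel counts: one Gaussian-weighted count per center."""
+     def count_one_kernel(mu):
+         return jnp.sum(multivariate_normal.pdf(data, mu, cov))
+
+     return jax.vmap(count_one_kernel)(kernel_centers)
+
+
+ # Example user-defined loss comparing true and model counts:
+ # cov = scotts_covariance(training_data)
+ # centers = training_data[:20]  # stand-in for the random kernel sample
+ # loss = jnp.sum((kde_counts(model_data, centers, cov)
+ #                 - kde_counts(training_data, centers, cov)) ** 2)
+ ```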
+ The analogous compare_fourier_counts
+ method can provide additional loss terms relating to differences in
+ the empirical characteristic function (ECF; Cramer
+ (1954)).
+ It is evaluated at a random sample of (20 by default) Fourier-space
+ positions $\vec{\tilde{x}}_k$ for both the “true” and “model”
+ Fourier counts:
+
+ $$\tilde{N}_k = \sum_i \exp(i\, \vec{\tilde{x}}_k \cdot \vec{x}_i).$$
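+
+ A corresponding jax sketch of these Fourier counts (again an
+ illustration rather than the compare_fourier_counts implementation;
+ the sampling of the wave vectors is an assumption):
+
+ ```python
+ # Illustrative empirical characteristic function (ECF) counts.
+ import jax
+ import jax.numpy as jnp
+
+
+ def fourier_counts(data, fourier_positions):
+     """N~_k = sum_i exp(i * x~_k . x_i): one complex count per wave vector."""
+     phases = fourier_positions @ data.T           # shape (n_k, n_points)
+     return jnp.sum(jnp.exp(1j * phases), axis=1)  # shape (n_k,)
+
+
+ # Example loss term comparing “true” and “model” ECF counts:
+ # k_vecs = jax.random.normal(jax.random.key(1), (20, model_data.shape[1]))
+ # ecf_loss = jnp.sum(jnp.abs(fourier_counts(model_data, k_vecs)
+ #                            - fourier_counts(training_data, k_vecs)) ** 2)
+ ```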
+
+
+ multiswarm
+ Particle swarm optimization (PSO; Kennedy & Eberhart
+ (1995))
+ is a highly exploratory fitting algorithm in which a set of (100 by
+ default) particles are initialized with randomized velocities and
+ positions with Latin-Hypercube spacing over the loss function’s
+ parameter space. Each particle has an inertial weight ($w_I = 1$ by
+ default), a cognitive weight ($w_C = 0.21$ by default), and a social
+ weight ($w_S = 0.07$ by default). The default parameters have been
+ hand-tuned to optimize parameter exploration performed by 100
+ particles before converging over roughly 100 time steps in a 4D
+ Ackley loss function demonstrated in our documentation.
+ Within each PSO iteration: (1) Each particle’s position is updated
+ according to its current velocity,
+
+ $$x_{i+1} = x_i + v_i.$$
+
+ (2) Positions and velocities are then reflected across any axes in
+ which they have left the boundaries, if applicable. (3) Finally, the
+ particle’s velocity is slightly pulled in the direction of the
+ positions of its personal best and the global best loss found so
+ far, weighted by the cognitive and social weights, respectively, as
+ sketched below.
+
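+ A minimal jax sketch of one such iteration, assuming the canonical
+ PSO velocity update (the random cognitive and social factors and the
+ reflect helper are illustrative and may differ from multiswarm’s
+ exact rule):
+
+ ```python
+ # One PSO iteration: advance, reflect at boundaries, update velocity.
+ import jax
+ import jax.numpy as jnp
+
+ W_INERTIAL, W_COGNITIVE, W_SOCIAL = 1.0, 0.21, 0.07
+
+
+ def reflect(x, v, lo, hi):
+     """Reflect positions and flip velocities across any crossed boundary."""
+     below, above = x < lo, x > hi
+     x = jnp.where(below, 2 * lo - x, jnp.where(above, 2 * hi - x, x))
+     v = jnp.where(below | above, -v, v)
+     return x, v
+
+
+ def pso_step(key, x, v, x_pbest, x_gbest, lo, hi):
+     x_new = x + v                              # (1) advance positions
+     x_new, v = reflect(x_new, v, lo, hi)       # (2) reflect at boundaries
+     key_c, key_s = jax.random.split(key)
+     xi_c = jax.random.uniform(key_c, x.shape)  # random cognitive factor
+     xi_s = jax.random.uniform(key_s, x.shape)  # random social factor
+     v_new = (W_INERTIAL * v                    # (3) pull velocity toward the
+              + W_COGNITIVE * xi_c * (x_pbest - x_new)  # personal best and
+              + W_SOCIAL * xi_s * (x_gbest - x_new))    # global best positions
+     return x_new, v_new
+ ```
+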
+ The multiswarm implementation of PSO
+ allows users to conveniently distribute the loss function
+ computations performed by each particle across MPI ranks. Particles
+ are evenly distributed across all ranks by default, but users
+ needing further control can provide a custom MPI communicator
+ object, and/or specify the ranks_per_particle
+ argument to manually control intra-particle parallelization.
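+ The particle-level parallelism can be pictured as a round-robin
+ split of particles over ranks, as in the following mpi4py sketch (an
+ illustration of the pattern only, not the multiswarm interface;
+ evaluate_swarm and loss_fn are hypothetical names):
+
+ ```python
+ # Distribute per-particle loss evaluations across MPI ranks.
+ import numpy as np
+ from mpi4py import MPI
+
+ comm = MPI.COMM_WORLD
+ rank, size = comm.Get_rank(), comm.Get_size()
+
+
+ def evaluate_swarm(loss_fn, particle_positions):
+     """Each rank evaluates its share of particles; results are gathered."""
+     my_losses = np.array([loss_fn(x) for x in particle_positions[rank::size]])
+
+     # Gather per-rank results and restore the original particle ordering
+     all_losses = np.empty(len(particle_positions))
+     for r, losses in enumerate(comm.allgather(my_losses)):
+         all_losses[r::size] = losses
+     return all_losses
+ ```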
+
+
+