This ssax repository demonstrates a proof of concept for the Sinkhorn Step, a batch gradient-free optimizer for highly non-convex objectives in JAX. ssax is heavily inspired by the code structure of OTT-JAX and reuses most of its linear solvers, so users can easily switch between different solver flavors.
NOTE: This repository implements the Sinkhorn Step optimizer (in JAX) as a general-purpose standalone solver for non-convex optimization problems. The MPOT trajectory optimizer using Sinkhorn Step (in PyTorch) is released in the mpot repository.
This work has been accepted to NeurIPS 2023. Please find the pre-print here:
Simply activate your conda/Python environment and run
pip install -e .
Please install JAX with CUDA support if you want to run the code on a GPU for better performance.
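To confirm that the installed JAX build actually sees a GPU, a quick check along these lines can help. This is only a sketch using the public jax.devices() and jax.default_backend() calls; it is not part of this repository.

```python
import jax

# List the devices JAX can see; a CUDA-enabled install should list GPU devices here.
devices = jax.devices()

# Platform name of the default backend (typically "cpu" on a CPU-only install).
backend = jax.default_backend()

print(f"backend={backend}, devices={devices}")
```

If only CPU devices are listed, the code still runs, just without GPU acceleration.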
An example script is provided in scripts/example.py.
For testing Sinkhorn Step with various synthetic functions, run the following script with hydra settings:
python scripts/run.py experiment=ss-al
and find the resulting animations in the logs/ folder. You can replace the tag experiment=<exp-filename> with filenames found in the configs/experiment folder. The currently available optimization experiments are:
- ss-al: Ackley function in 2D
- ss-al-10d: Ackley function in 10D
- ss-bk: Bukin function in 2D
- ss-dw: DropWave function in 2D
- ss-eh: EggHolder function in 2D
- ss-ht: Hoelder Table function in 2D
- ss-lv: Levy function in 2D
- ss-rb: Rosenbrock function in 2D
- ss-rg: Rastrigin function in 2D
- ss-st: Styblinski-Tang function in 2D
- ss-st-10d: Styblinski-Tang function in 10D
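For readers unfamiliar with these test functions, here is a minimal JAX sketch of the Ackley function evaluated on a batch of points. It uses the textbook definition with the commonly used constants a=20, b=0.2, c=2π. The function name, constants, and sampling range are assumptions for illustration and may differ from the implementation in this repository.

```python
import jax
import jax.numpy as jnp


def ackley(x, a=20.0, b=0.2, c=2.0 * jnp.pi):
    """Textbook Ackley function; reduces over the last axis, so batches work too."""
    d = x.shape[-1]
    term1 = -a * jnp.exp(-b * jnp.sqrt(jnp.sum(x**2, axis=-1) / d))
    term2 = -jnp.exp(jnp.sum(jnp.cos(c * x), axis=-1) / d)
    return term1 + term2 + a + jnp.e


# Evaluate a batch of 128 random 2D points (sampling range chosen for illustration).
key = jax.random.PRNGKey(0)
points = jax.random.uniform(key, (128, 2), minval=-5.0, maxval=5.0)
values = ackley(points)  # shape (128,)

# The global minimum is at the origin, where the value is 0 up to float error.
value_at_origin = ackley(jnp.zeros(2))
```

Because the function reduces only over the last axis, the same code handles the 10D variants by changing the shape of the sampled points.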
Note: When tuning new settings, the most sensitive hyperparameters are step_radius, probe_radius, the entropic regularization scalar ent_epsilon, and the parameters of the step-annealing scheme epsilon_scheduler. You can vary these together with the other hyperparameters on the synthetic functions to get a feel for how they affect the optimization. For the 10D experiments, the plots are projected onto the first 2 dimensions for visualization.
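To build intuition for how step annealing interacts with the radii, the sketch below shrinks a radius geometrically over iterations down to a floor. This is a hypothetical schedule for illustration only: the helper annealed_radius, its decay and min_scale arguments, and the base radii are assumptions, not the epsilon_scheduler used in this repository.

```python
import jax.numpy as jnp


def annealed_radius(base_radius, iteration, decay=0.95, min_scale=0.01):
    """Hypothetical geometric annealing: base_radius * max(decay**iteration, min_scale)."""
    scale = jnp.maximum(decay ** iteration, min_scale)
    return base_radius * scale


iterations = jnp.arange(100)

# Illustrative base values only; they are not the defaults of any experiment config.
step_radii = annealed_radius(0.15, iterations)
probe_radii = annealed_radius(0.2, iterations)
```

A faster decay makes the particles commit to a region sooner, while a slower decay keeps exploring for longer; plotting the two arrays against the iteration index makes this trade-off visible.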
We also provide gradient approximation benchmarks that measure the cosine similarity between the Sinkhorn Step and the true gradient, over outer iterations and over the entropic regularization of the Sinkhorn distance. Step annealing is turned off for benchmarking purposes. The currently available gradient approximation experiments are:
- ss-al-cosin-sim: Ackley function in 10D
- ss-st-cosin-sim: Styblinski-Tang function in 10D
To run them:
python scripts/benchmark_cosin_similarity_single.py experiment=ss-st-cosin-sim num_seeds=20
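The cosine-similarity measure behind these benchmarks can be illustrated with a minimal JAX sketch. It compares the true gradient of the Styblinski-Tang function (obtained with jax.grad) against a gradient-free estimate. Note the assumption: the estimate here is a central finite-difference stand-in, not the Sinkhorn Step itself, and all function names are hypothetical rather than taken from the benchmark script.

```python
import jax
import jax.numpy as jnp


def styblinski_tang(x):
    """Textbook Styblinski-Tang function for a single point x of shape (d,)."""
    return 0.5 * jnp.sum(x**4 - 16.0 * x**2 + 5.0 * x)


def cosine_similarity(u, v, eps=1e-12):
    """Cosine of the angle between u and v; eps guards against division by zero."""
    return jnp.dot(u, v) / (jnp.linalg.norm(u) * jnp.linalg.norm(v) + eps)


def finite_difference_direction(f, x, h=1e-2):
    """Stand-in gradient-free estimate: central differences along each axis."""
    basis = jnp.eye(x.shape[0])
    return jax.vmap(lambda e: (f(x + h * e) - f(x - h * e)) / (2.0 * h))(basis)


x = jnp.linspace(-2.0, 2.0, 10)  # one 10D query point
true_grad = jax.grad(styblinski_tang)(x)
approx_grad = finite_difference_direction(styblinski_tang, x)
similarity = cosine_similarity(approx_grad, true_grad)
```

A similarity close to 1 means the gradient-free estimate points in nearly the same direction as the true gradient; in the benchmarks, the direction produced by the Sinkhorn Step takes the place of the finite-difference estimate.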
If you found this work useful, please consider citing this reference:
@inproceedings{le2023accelerating,
title={Accelerating Motion Planning via Optimal Transport},
author={Le, An T. and Chalvatzaki, Georgia and Biess, Armin and Peters, Jan},
booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
year={2023}
}
- The OTT-JAX documentation for more details on the linear solvers.
- This wonderful library of synthetic test functions from Sonja & Derek (SFU).