TensorFlow 2.x reimplementation of "PonderNet: Learning to Ponder" by Andrea Banino, Jan Balaguer and Charles Blundell.
In standard neural networks the amount of computation used grows with the size of the inputs, but not with the complexity of the problem being learnt. To overcome this limitation we introduce PonderNet, a new algorithm that learns to adapt the amount of computation based on the complexity of the problem at hand. PonderNet learns end-to-end the number of computational steps to achieve an effective compromise between training prediction accuracy, computational cost and generalization. On a complex synthetic problem, PonderNet dramatically improves performance over previous adaptive computation methods and additionally succeeds at extrapolation tests where traditional neural networks fail. Also, our method matched the current state of the art results on a real world question and answering dataset, but using less compute. Finally, PonderNet reached state of the art results on a complex task designed to test the reasoning capabilities of neural networks.
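The abstract does not spell out the mechanism. In the paper, at every step n the network emits a prediction ŷ_n and a halting probability λ_n. The probability of halting exactly at step n is p_n = λ_n ∏_{j<n} (1 − λ_j), and training minimises the expected per-step loss under p_n plus β times KL(p_n || p_G(λ_p)), where p_G is a geometric prior. The sketch below computes these quantities in plain Python, assuming the per-step losses are already available. Forcing halting at the last unrolled step (λ_N = 1), renormalising the truncated geometric prior, and the default `lambda_p` and `beta` values are assumptions of this sketch. The function names are hypothetical and are not this repository's API.

```python
import math
from typing import List, Sequence


def halting_distribution(lambdas: Sequence[float]) -> List[float]:
    """p_n = lambda_n * prod_{j<n} (1 - lambda_j).

    Assumption of this sketch: the last lambda is forced to 1 so the
    distribution over the N unrolled steps sums to one.
    """
    lambdas = list(lambdas[:-1]) + [1.0]
    probs = []
    not_halted = 1.0  # probability of not having halted before step n
    for lam in lambdas:
        probs.append(not_halted * lam)
        not_halted *= 1.0 - lam
    return probs


def truncated_geometric_prior(lambda_p: float, n_steps: int) -> List[float]:
    """p_G(n) = lambda_p * (1 - lambda_p)^(n-1), renormalised over n_steps (assumption)."""
    if not 0.0 < lambda_p < 1.0:
        raise ValueError("lambda_p must lie in (0, 1)")
    unnorm = [lambda_p * (1.0 - lambda_p) ** n for n in range(n_steps)]
    total = sum(unnorm)
    return [u / total for u in unnorm]


def kl_divergence(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(p || q) over a finite support; terms with p_i == 0 contribute 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)


def ponder_loss(step_losses: Sequence[float], lambdas: Sequence[float],
                lambda_p: float = 0.2, beta: float = 0.01) -> float:
    """Expected step loss under p_n plus beta * KL(p_n || p_G(lambda_p)).

    The default lambda_p and beta are illustrative, not the repository's settings.
    """
    if len(step_losses) != len(lambdas):
        raise ValueError("need one loss and one lambda per step")
    p = halting_distribution(lambdas)
    reconstruction = sum(pn * ln for pn, ln in zip(p, step_losses))
    regularisation = kl_divergence(p, truncated_geometric_prior(lambda_p, len(p)))
    return reconstruction + beta * regularisation
```

For example, `halting_distribution([0.5, 0.5, 0.5])` gives `[0.5, 0.25, 0.25]`. With `beta=0` the loss reduces to the expected step loss. Raising `lambda_p` concentrates the prior, and therefore the regularised p_n, on earlier steps.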
The input to the parity task is a vector of 0s, 1s and −1s. The output is the parity of the number of 1s: one if the vector contains an odd number of 1s and zero otherwise. Each input is generated by setting a random number of the vector's elements to either 1 or −1, leaving the rest at 0.
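The data generation described above can be sketched with the standard library alone. The vector length of 64 and drawing the number of non-zero elements uniformly from 1 to the vector length are assumptions of this sketch, since the description only says "a random number". `make_parity_example` is a hypothetical helper, not part of this repository.

```python
import random
from typing import List, Optional, Tuple


def make_parity_example(
    n_elems: int = 64, rng: Optional[random.Random] = None
) -> Tuple[List[float], int]:
    """Return (x, y): x holds 0s, 1s and -1s; y is 1 if x has an odd number of 1s."""
    rng = rng or random.Random()
    x = [0.0] * n_elems
    # Assumption: between 1 and n_elems positions become non-zero.
    n_nonzero = rng.randint(1, n_elems)
    for i in rng.sample(range(n_elems), n_nonzero):
        x[i] = rng.choice([1.0, -1.0])
    y = x.count(1.0) % 2  # parity of the number of 1s; -1s are ignored
    return x, y


rng = random.Random(0)
batch = [make_parity_example(64, rng) for _ in range(128)]
```

Because the label depends only on the 1s, the −1s act as distractors: the model has to count the 1s rather than simply count non-zero entries.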
Performance on the parity task. a) Interpolation. Top: accuracy for both PonderNet (blue) and ACT (orange). Bottom: number of ponder steps at evaluation time. Error bars are calculated over 10 random seeds. b) Extrapolation. Top: accuracy for both PonderNet (blue) and ACT (orange). Bottom: number of ponder steps at evaluation time. Error bars are calculated over 10 random seeds. c) Total number of compute steps calculated as the number of actual forward passes performed by each network. Blue is PonderNet, Green is ACT and Orange is an RNN without adaptive compute.
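The "number of ponder steps at evaluation time" in panels a and b comes from running the step function until the network decides to halt. The sketch below shows one way to count these steps. It assumes, following the paper's description, that halting at step n is sampled from a Bernoulli distribution with parameter λ_n, and that a hypothetical cap `max_steps` forces halting. `step_fn` and `toy_step` are hypothetical stand-ins for one recurrent step of the model.

```python
import random
from typing import Callable, Optional, Tuple

# Hypothetical step signature: state -> (new_state, prediction, lambda_n)
StepFn = Callable[[object], Tuple[object, object, float]]


def ponder(step_fn: StepFn, state: object, max_steps: int,
           rng: Optional[random.Random] = None) -> Tuple[object, int]:
    """Run step_fn until a Bernoulli(lambda_n) draw says halt.

    Returns the prediction at the halting step and the number of
    forward passes (ponder steps) that were actually executed.
    """
    rng = rng or random.Random()
    prediction = None
    for n in range(1, max_steps + 1):
        state, prediction, lam = step_fn(state)
        if rng.random() < lam:  # halt with probability lambda_n
            return prediction, n
    return prediction, max_steps  # forced halt after max_steps


def toy_step(t: int) -> Tuple[int, str, float]:
    # Hypothetical toy model: the state is a step counter that halts at step 3.
    t += 1
    return t, f"prediction@{t}", 1.0 if t == 3 else 0.0


print(ponder(toy_step, 0, max_steps=10))  # → ('prediction@3', 3)
```

Averaging the returned step count over an evaluation set gives a per-model ponder-step figure. Each loop iteration is one forward pass, which also makes this count comparable to the compute measure in panel c.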
Clone the repository and install the required packages
git clone https://github.com/EMalagoli92/PonderNet-TensorFlow.git
cd PonderNet-TensorFlow
pip install -r requirements.txt
Tested on Ubuntu 20.04.4 LTS x86_64 with Python 3.9.7.
Train a PonderNet on the parity task
python __main__.py
PonderNet (Official PyTorch Implementation)
@misc{banino2021pondernet,
title={PonderNet: Learning to Ponder},
author={Andrea Banino and Jan Balaguer and Charles Blundell},
year={2021},
eprint={2107.05407},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
This work is made available under the MIT License.