torch_runstats

Running/online statistics for PyTorch. `torch_runstats` implements memory-efficient online reductions on tensors.
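To illustrate what an "online reduction" means here, the following is a minimal running mean in plain Python (a conceptual sketch, not this library's API): batches are folded in one at a time, and only a count and a running value are kept, so memory use is constant no matter how many samples are accumulated.

```python
class RunningMean:
    """Minimal online (streaming) mean with O(1) memory.

    Illustrative sketch only -- not the torch_runstats API.
    Uses Welford-style incremental updates.
    """

    def __init__(self):
        self.n = 0        # number of samples seen so far
        self.mean = 0.0   # current running mean

    def accumulate_batch(self, batch):
        # Fold a batch into the running mean without storing
        # any previous batches.
        for x in batch:
            self.n += 1
            self.mean += (x - self.mean) / self.n

    def current_result(self):
        return self.mean


rs = RunningMean()
rs.accumulate_batch([1.0, 2.0, 3.0])
rs.accumulate_batch([4.0])
print(rs.current_result())  # mean of 1, 2, 3, 4 -> 2.5
```

The incremental update avoids both storing all samples and summing large totals, which is the same trade-off the library makes for tensors.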
Notable features:

- Arbitrary sample shapes beyond single scalars
- Reduction over arbitrary dimensions of each sample
- "Batched"/"binned" reduction into multiple running tallies using a per-sample bin index. This can be useful, for example, for accumulating statistics over samples by some kind of "type" index, or for accumulating statistics per-graph in a `pytorch_geometric`-like batching scheme. (This feature is similar to `torch_scatter`.)
- Option to ignore NaN values with correct sample counting
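The binned reduction and NaN handling above can be sketched together in plain Python (again a conceptual illustration under assumed semantics, not the library's implementation or API): each sample carries a bin index, each bin keeps its own count and running mean, and NaN samples are skipped without inflating the counts.

```python
import math


class BinnedRunningMean:
    """Sketch of a binned running mean with NaN skipping.

    Illustrative only -- not the torch_runstats API. Each bin keeps
    its own (count, mean) pair, so counts stay correct even when
    some samples are NaN.
    """

    def __init__(self, n_bins):
        self.n = [0] * n_bins
        self.mean = [0.0] * n_bins

    def accumulate_batch(self, batch, accumulate_by):
        # accumulate_by[i] is the bin index of sample batch[i]
        for x, b in zip(batch, accumulate_by):
            if math.isnan(x):
                continue  # ignore NaN; do not increment the count
            self.n[b] += 1
            self.mean[b] += (x - self.mean[b]) / self.n[b]

    def current_result(self):
        return list(self.mean)


rs = BinnedRunningMean(n_bins=2)
rs.accumulate_batch([1.0, 3.0, float("nan"), 10.0],
                    accumulate_by=[0, 0, 0, 1])
print(rs.current_result())  # [2.0, 10.0]
```

In a graph-batching setting, `accumulate_by` would play the role of the per-node graph index, yielding one running statistic per graph.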
Note: the implementations currently make heavy use of in-place operations for performance and memory efficiency. This probably doesn't play nice with the autograd engine, so this is likely the wrong library for accumulating running statistics you want to backward through. (See TorchMetrics for a possible alternative.)
For more information, please see the docs.
`torch_runstats` requires PyTorch. The library can be installed from PyPI:
$ pip install torch_runstats
The latest development version of the code can also be installed from git:
$ git clone https://github.com/mir-group/pytorch_runstats
and install it by running
$ cd torch_runstats/
$ pip install .
You can run the tests with
$ pytest tests/
`pytorch_runstats` is distributed under an MIT license.