pytorch_runstats

Running/online statistics for PyTorch.

torch_runstats implements memory-efficient online reductions on tensors.

Notable features:

  • Arbitrary sample shapes beyond single scalars
  • Reduction over arbitrary dimensions of each sample
  • "Batched"/"binned" reduction into multiple running tallies using a per-sample bin index. This can be useful, for example, in accumulating statistics over samples by some kind of "type" index or for accumulating statistics per-graph in a pytorch_geometric-like batching scheme. (This feature is similar to torch_scatter.)
  • Option to ignore NaN values with correct sample counting.
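As a rough illustration of what "binned" running reduction with NaN handling means, here is a minimal framework-free sketch in plain Python. It is not torch_runstats code and does not use its API: the class BinnedRunningMean and its update/result methods are hypothetical names invented for this example, and it handles only scalar samples rather than tensors.

```python
import math


class BinnedRunningMean:
    """Hypothetical helper: keeps one running mean per bin index."""

    def __init__(self, ignore_nan=False):
        self.ignore_nan = ignore_nan
        self.counts = {}  # bin index -> number of samples counted so far
        self.means = {}   # bin index -> current running mean

    def update(self, values, bins):
        """Fold one batch of scalar samples into the per-bin tallies."""
        for value, b in zip(values, bins):
            if self.ignore_nan and math.isnan(value):
                continue  # a skipped NaN must not increase the sample count
            n = self.counts.get(b, 0) + 1
            mean = self.means.get(b, 0.0)
            # Incremental update: no past samples need to be stored
            self.means[b] = mean + (value - mean) / n
            self.counts[b] = n

    def result(self):
        return dict(self.means)


stats = BinnedRunningMean(ignore_nan=True)
stats.update([1.0, 2.0, float("nan")], bins=[0, 1, 0])
stats.update([3.0, 4.0], bins=[0, 1])
result = stats.result()  # {0: 2.0, 1: 3.0}
```

Only the count and current mean per bin are kept between batches, which is what makes the reduction memory-efficient. The NaN in the first batch is dropped without being counted, so bin 0 averages two samples rather than three.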

Note: the implementation currently makes heavy use of in-place operations for performance and memory efficiency. This probably does not play well with the autograd engine, so this is currently likely the wrong library for accumulating running statistics that you want to backpropagate through. (See TorchMetrics for a possible alternative.)

For more information, please see the docs.

Install

torch_runstats requires PyTorch.

The library can be installed from PyPI:

$ pip install torch_runstats

To install the latest development version of the code, clone the repository from git:

$ git clone https://github.com/mir-group/pytorch_runstats

and install it by running

$ cd pytorch_runstats/
$ pip install .

You can run the tests with

$ pytest tests/

License

pytorch_runstats is distributed under an MIT license.