Skip to content

Commit

Permalink
Compute pairwise distances (#278)
Browse files Browse the repository at this point in the history
* Draft inter-individual distances

* Return vector norm in `compute_interindividual_distances`

* Add `compute_interkeypoint_distances`

* Refactor pairwise distances tests

* Use `scipy.spatial.distance.cdist`

* Add examples to docstrings

* Rename variables

* Update test function args + fix indentation

* Handle scalar and 1d dims

* Handle missing `core_dim`

* Refactor `cdist` and tests

* Fix docstrings

* Reorder functions + cleanup docs

* Reduce pairwise distances functions

* Mention examples of available distance metrics

* Update docstrings

* Require `pairs` in `compute_pairwise_distances`

* Raise error if there are no pairs to compute distances for

* Rename `core_dim` to `labels_dim`

* Spell out expected pairs in test

* Merge old `kinematics` file changes into new

* Rename `core_dim` to `labels_dim` in tests

* Validate dims in `compute_pairwise_distances`

* Apply suggestions from code review

Co-authored-by: Niko Sirmpilatze <[email protected]>

* Apply suggestions from code review

Co-authored-by: Niko Sirmpilatze <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Niko Sirmpilatze <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Oct 29, 2024
1 parent b10896f commit 536de0e
Show file tree
Hide file tree
Showing 3 changed files with 540 additions and 3 deletions.
332 changes: 332 additions & 0 deletions movement/kinematics.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""Compute kinematic variables like velocity and acceleration."""

import itertools
from typing import Literal

import numpy as np
import xarray as xr
from scipy.spatial.distance import cdist

from movement.utils.logging import log_error
from movement.utils.vector import compute_norm
Expand Down Expand Up @@ -324,6 +326,336 @@ def compute_head_direction_vector(
)


def _cdist(
a: xr.DataArray,
b: xr.DataArray,
dim: Literal["individuals", "keypoints"],
metric: str | None = "euclidean",
**kwargs,
) -> xr.DataArray:
"""Compute distances between two position arrays across a given dimension.
This function is a wrapper around :func:`scipy.spatial.distance.cdist`
and computes the pairwise distances between the two input position arrays
across the dimension specified by ``dim``.
The dimension can be either ``individuals`` or ``keypoints``.
The distances are computed using the specified ``metric``.
Parameters
----------
a : xarray.DataArray
The first input data containing position information of a
single individual or keypoint, with ``time``, ``space``
(in Cartesian coordinates), and ``individuals`` or ``keypoints``
(as specified by ``dim``) as required dimensions.
b : xarray.DataArray
The second input data containing position information of a
single individual or keypoint, with ``time``, ``space``
(in Cartesian coordinates), and ``individuals`` or ``keypoints``
(as specified by ``dim``) as required dimensions.
dim : str
The dimension to compute the distances for. Must be either
``'individuals'`` or ``'keypoints'``.
metric : str, optional
The distance metric to use. Must be one of the options supported
by :func:`scipy.spatial.distance.cdist`, e.g. ``'cityblock'``,
``'euclidean'``, etc.
Defaults to ``'euclidean'``.
**kwargs : dict
Additional keyword arguments to pass to
:func:`scipy.spatial.distance.cdist`.
Returns
-------
xarray.DataArray
An xarray DataArray containing the computed distances between
each pair of inputs.
Examples
--------
Compute the Euclidean distance (default) between ``ind1`` and
``ind2`` (i.e. interindividual distance for all keypoints)
using the ``position`` data variable in the Dataset ``ds``:
>>> pos1 = ds.position.sel(individuals="ind1")
>>> pos2 = ds.position.sel(individuals="ind2")
>>> ind_dists = _cdist(pos1, pos2, dim="individuals")
Compute the Euclidean distance (default) between ``key1`` and
``key2`` (i.e. interkeypoint distance for all individuals)
using the ``position`` data variable in the Dataset ``ds``:
>>> pos1 = ds.position.sel(keypoints="key1")
>>> pos2 = ds.position.sel(keypoints="key2")
>>> key_dists = _cdist(pos1, pos2, dim="keypoints")
See Also
--------
scipy.spatial.distance.cdist : The underlying function used.
compute_pairwise_distances : Compute pairwise distances between
``individuals`` or ``keypoints``
"""
# The dimension from which ``dim`` labels are obtained
labels_dim = "individuals" if dim == "keypoints" else "keypoints"
elem1 = getattr(a, dim).item()
elem2 = getattr(b, dim).item()
a = _validate_labels_dimension(a, labels_dim)
b = _validate_labels_dimension(b, labels_dim)
result = xr.apply_ufunc(
cdist,
a,
b,
kwargs={"metric": metric, **kwargs},
input_core_dims=[[labels_dim, "space"], [labels_dim, "space"]],
output_core_dims=[[elem1, elem2]],
vectorize=True,
)
result = result.assign_coords(
{
elem1: getattr(a, labels_dim).values,
elem2: getattr(a, labels_dim).values,
}
)
# Drop any squeezed coordinates
return result.squeeze(drop=True)


def compute_pairwise_distances(
data: xr.DataArray,
dim: Literal["individuals", "keypoints"],
pairs: dict[str, str | list[str]] | Literal["all"],
metric: str | None = "euclidean",
**kwargs,
) -> xr.DataArray | dict[str, xr.DataArray]:
"""Compute pairwise distances between ``individuals`` or ``keypoints``.
This function computes the distances between
pairs of ``individuals`` (i.e. interindividual distances) or
pairs of ``keypoints`` (i.e. interkeypoint distances),
as determined by ``dim``.
The distances are computed for the given ``pairs``
using the specified ``metric``.
Parameters
----------
data : xarray.DataArray
The input data containing position information, with ``time``,
``space`` (in Cartesian coordinates), and
``individuals`` or ``keypoints`` (as specified by ``dim``)
as required dimensions.
dim : Literal["individuals", "keypoints"]
The dimension to compute the distances for. Must be either
``'individuals'`` or ``'keypoints'``.
pairs : dict[str, str | list[str]] or 'all'
Specifies the pairs of elements (either individuals or keypoints)
for which to compute distances, depending on the value of ``dim``.
- If ``dim='individuals'``, ``pairs`` should be a dictionary where
each key is an individual name, and each value is also an individual
name or a list of such names to compute distances with.
- If ``dim='keypoints'``, ``pairs`` should be a dictionary where each
key is a keypoint name, and each value is also keypoint name or a
list of such names to compute distances with.
- Alternatively, use the special keyword ``'all'`` to compute distances
for all possible pairs of individuals or keypoints
(depending on ``dim``).
metric : str, optional
The distance metric to use. Must be one of the options supported
by :func:`scipy.spatial.distance.cdist`, e.g. ``'cityblock'``,
``'euclidean'``, etc.
Defaults to ``'euclidean'``.
**kwargs : dict
Additional keyword arguments to pass to
:func:`scipy.spatial.distance.cdist`.
Returns
-------
xarray.DataArray or dict[str, xarray.DataArray]
The computed pairwise distances. If a single pair is specified in
``pairs``, returns an :class:`xarray.DataArray`. If multiple pairs
are specified, returns a dictionary where each key is a string
representing the pair (e.g., ``'dist_ind1_ind2'`` or
``'dist_key1_key2'``) and each value is an :class:`xarray.DataArray`
containing the computed distances for that pair.
Raises
------
ValueError
If ``dim`` is not one of ``'individuals'`` or ``'keypoints'``;
if ``pairs`` is not a dictionary or ``'all'``; or
if there are no pairs in ``data`` to compute distances for.
Examples
--------
Compute the Euclidean distance (default) between ``ind1`` and ``ind2``
(i.e. interindividual distance), for all possible pairs of keypoints.
>>> position = xr.DataArray(
... np.arange(36).reshape(2, 3, 3, 2),
... coords={
... "time": np.arange(2),
... "individuals": ["ind1", "ind2", "ind3"],
... "keypoints": ["key1", "key2", "key3"],
... "space": ["x", "y"],
... },
... dims=["time", "individuals", "keypoints", "space"],
... )
>>> dist_ind1_ind2 = compute_pairwise_distances(
... position, "individuals", {"ind1": "ind2"}
... )
>>> dist_ind1_ind2
<xarray.DataArray (time: 2, ind1: 3, ind2: 3)> Size: 144B
8.485 11.31 14.14 5.657 8.485 11.31 ... 5.657 8.485 11.31 2.828 5.657 8.485
Coordinates:
* time (time) int64 16B 0 1
* ind1 (ind1) <U4 48B 'key1' 'key2' 'key3'
* ind2 (ind2) <U4 48B 'key1' 'key2' 'key3'
The resulting ``dist_ind1_ind2`` is a DataArray containing the computed
distances between ``ind1`` and ``ind2`` for all keypoints
at each time point.
To obtain the distances between ``key1`` of ``ind1`` and
``key2`` of ``ind2``:
>>> dist_ind1_ind2.sel(ind1="key1", ind2="key2")
Compute the Euclidean distance (default) between ``key1`` and ``key2``
(i.e. interkeypoint distance), for all possible pairs of individuals.
>>> dist_key1_key2 = compute_pairwise_distances(
... position, "keypoints", {"key1": "key2"}
... )
>>> dist_key1_key2
<xarray.DataArray (time: 2, key1: 3, key2: 3)> Size: 144B
2.828 11.31 19.8 5.657 2.828 11.31 14.14 ... 2.828 11.31 14.14 5.657 2.828
Coordinates:
* time (time) int64 16B 0 1
* key1 (key1) <U4 48B 'ind1' 'ind2' 'ind3'
* key2 (key2) <U4 48B 'ind1' 'ind2' 'ind3'
The resulting ``dist_key1_key2`` is a DataArray containing the computed
distances between ``key1`` and ``key2`` for all individuals
at each time point.
To obtain the distances between ``key1`` and ``key2`` within ``ind1``:
>>> dist_key1_key2.sel(key1="ind1", key2="ind1")
To obtain the distances between ``key1`` of ``ind1`` and
``key2`` of ``ind2``:
>>> dist_key1_key2.sel(key1="ind1", key2="ind2")
Compute the city block or Manhattan distance for multiple pairs of
keypoints using ``position``:
>>> key_dists = compute_pairwise_distances(
... position,
... "keypoints",
... {"key1": "key2", "key3": ["key1", "key2"]},
... metric="cityblock",
... )
>>> key_dists.keys()
dict_keys(['dist_key1_key2', 'dist_key3_key1', 'dist_key3_key2'])
As multiple pairs of keypoints are specified,
the resulting ``key_dists`` is a dictionary containing the DataArrays
of computed distances for each pair of keypoints.
Compute the city block or Manhattan distance for all possible pairs of
individuals using ``position``:
>>> ind_dists = compute_pairwise_distances(
... position,
... "individuals",
... "all",
... metric="cityblock",
... )
>>> ind_dists.keys()
dict_keys(['dist_ind1_ind2', 'dist_ind1_ind3', 'dist_ind2_ind3'])
See Also
--------
scipy.spatial.distance.cdist : The underlying function used.
"""
if dim not in ["individuals", "keypoints"]:
raise log_error(
ValueError,
"'dim' must be either 'individuals' or 'keypoints', "
f"but got {dim}.",
)
if isinstance(pairs, str) and pairs != "all":
raise log_error(
ValueError,
f"'pairs' must be a dictionary or 'all', but got {pairs}.",
)
validate_dims_coords(data, {"time": [], "space": ["x", "y"], dim: []})
# Find all possible pair combinations if 'all' is specified
if pairs == "all":
paired_elements = list(
itertools.combinations(getattr(data, dim).values, 2)
)
else:
paired_elements = [
(elem1, elem2)
for elem1, elem2_list in pairs.items()
for elem2 in
(
# Ensure elem2_list is a list
[elem2_list] if isinstance(elem2_list, str) else elem2_list
)
]
if not paired_elements:
raise log_error(
ValueError, "Could not find any pairs to compute distances for."
)
pairwise_distances = {
f"dist_{elem1}_{elem2}": _cdist(
data.sel({dim: elem1}),
data.sel({dim: elem2}),
dim=dim,
metric=metric,
**kwargs,
)
for elem1, elem2 in paired_elements
}
# Return DataArray if result only has one key
if len(pairwise_distances) == 1:
return next(iter(pairwise_distances.values()))
return pairwise_distances


def _validate_labels_dimension(data: xr.DataArray, dim: str) -> xr.DataArray:
"""Validate the input data contains the ``dim`` for labelling dimensions.
This function ensures the input data contains the ``dim``
used as labels (coordinates) when applying
:func:`scipy.spatial.distance.cdist` to
the input data, by adding a temporary dimension if necessary.
Parameters
----------
data : xarray.DataArray
The input data to validate.
dim : str
The dimension to validate.
Returns
-------
xarray.DataArray
The input data with the labels dimension validated.
"""
if data.coords.get(dim) is None:
data = data.assign_coords({dim: "temp_dim"})
if data.coords[dim].ndim == 0:
data = data.expand_dims(dim).transpose("time", "space", dim)
return data


def _validate_type_data_array(data: xr.DataArray) -> None:
"""Validate the input data is an xarray DataArray.
Expand Down
5 changes: 2 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,7 @@ def missing_two_dims_bboxes_dataset(valid_bboxes_dataset):
return valid_bboxes_dataset.rename({"time": "tame", "space": "spice"})


# --------------------------- Kinematics fixtures ---------------------------
@pytest.fixture(params=["displacement", "velocity", "acceleration"])
def kinematic_property(request):
"""Return a kinematic property."""
Expand Down Expand Up @@ -820,6 +821,7 @@ def track_ids_not_unique_per_frame(
return file_path


# ----------------- Helpers fixture -----------------
class Helpers:
"""Generic helper methods for ``movement`` test modules."""

Expand All @@ -834,9 +836,6 @@ def count_consecutive_nans(da):
return (da.isnull().astype(int).diff("time") == 1).sum().item()


# ----------------- Helper fixture -----------------


@pytest.fixture
def helpers():
"""Return an instance of the ``Helpers`` class."""
Expand Down
Loading

0 comments on commit 536de0e

Please sign in to comment.